From 0fe1c710a5d91249556ac767a08b6f723e29b7ba Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 Mar 2022 22:23:50 +0800 Subject: [PATCH] fix misleading function alias name for RuntimeReplaceable --- .../spark/sql/catalyst/optimizer/finishAnalysis.scala | 9 +++++++-- .../test/scala/org/apache/spark/sql/ExplainSuite.scala | 9 ++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 7b896e2c9607c..ef9c4b9af40d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -40,9 +40,14 @@ import org.apache.spark.util.Utils * we use this to replace Every and Any with Min and Max respectively. */ object ReplaceExpressions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsAnyPattern(RUNTIME_REPLACEABLE)) { - case e: RuntimeReplaceable => e.replacement + case p => p.mapExpressions(replace) + } + + private def replace(e: Expression): Expression = e match { + case r: RuntimeReplaceable => replace(r.replacement) + case _ => e.mapChildren(replace) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index 3659f20fb6ec2..073b67e0472bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -106,7 +106,7 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite keywords = "InMemoryRelation", "StorageLevel(disk, memory, deserialized, 1 replicas)") } - test("optimized plan should show the rewritten aggregate expression") { + test("optimized plan should show the rewritten expression") { withTempView("test_agg") { sql( """ @@ -125,6 +125,13 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite "Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, " + "any(v#x) AS any(v)#x]") } + + withTable("t") { + sql("CREATE TABLE t(col TIMESTAMP) USING parquet") + val df = sql("SELECT date_part('month', col) FROM t") + checkKeywordsExistsInExplain(df, + "Project [month(cast(col#x as date)) AS date_part(month, col)#x]") + } } test("explain inline tables cross-joins") {