From 3976d9515659a7ac5d8e30a2ce1376c63ca8af3f Mon Sep 17 00:00:00 2001
From: Ole Sasse
Date: Thu, 9 Jun 2022 09:42:54 +0200
Subject: [PATCH] [SPARK-39259][SQL][3.1] Evaluate timestamps consistently in subqueries

---
 .../catalyst/optimizer/finishAnalysis.scala   | 27 ++++----
 .../spark/sql/catalyst/plans/QueryPlan.scala  | 22 ++++++
 .../optimizer/ComputeCurrentTimeSuite.scala   | 67 ++++++++++++++-----
 3 files changed, 84 insertions(+), 32 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index 1f2389176d1e0..2a5b6396abb26 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -17,13 +17,14 @@
 package org.apache.spark.sql.catalyst.optimizer
 
-import scala.collection.mutable
+import java.time.Instant
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -72,21 +73,19 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
  */
 object ComputeCurrentTime extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
-    val currentDates = mutable.Map.empty[String, Literal]
-    val timeExpr = CurrentTimestamp()
-    val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long]
-    val currentTime = Literal.create(timestamp, timeExpr.dataType)
+    val instant = Instant.now()
+    val currentTimestampMicros = instantToMicros(instant)
+    val currentTime = Literal.create(currentTimestampMicros, TimestampType)
     val timezone = Literal.create(SQLConf.get.sessionLocalTimeZone, StringType)
 
-    plan transformAllExpressions {
-      case currentDate @ CurrentDate(Some(timeZoneId)) =>
-        currentDates.getOrElseUpdate(timeZoneId, {
-          Literal.create(
-            DateTimeUtils.microsToDays(timestamp, currentDate.zoneId),
-            DateType)
-        })
-      case CurrentTimestamp() | Now() => currentTime
-      case CurrentTimeZone() => timezone
+    plan.transformDownWithSubqueries {
+      case subQuery =>
+        subQuery.transformAllExpressions {
+          case cd: CurrentDate =>
+            Literal.create(DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
+          case CurrentTimestamp() | Now() => currentTime
+          case CurrentTimeZone() => timezone
+        }
     }
   }
 }
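Note: the rewritten rule above hinges on reading the clock exactly once per
optimizer run and substituting that single literal everywhere, including
inside subquery plans. Previously the rule only rewrote the expressions of
the plan it was invoked on, so timestamp expressions inside subquery plans
could end up replaced with a different (later) value. The following is a
minimal standalone sketch of the idea in plain Scala; the Expr/Query types
and all names in it are hypothetical stand-ins, not Spark APIs:

    import java.time.Instant
    import java.time.temporal.ChronoUnit

    sealed trait Expr
    case object CurrentTs extends Expr              // stands in for CurrentTimestamp()/Now()
    case class LitMicros(value: Long) extends Expr  // stands in for Literal(..., TimestampType)
    case class Subquery(plan: Query) extends Expr   // stands in for PlanExpression
    case class Query(exprs: Seq[Expr])

    def computeCurrentTime(plan: Query): Query = {
      // Read the clock once per rule invocation ...
      val micros = ChronoUnit.MICROS.between(Instant.EPOCH, Instant.now())
      val current = LitMicros(micros)
      // ... and substitute the same literal in the outer plan and in every subquery.
      def rewrite(q: Query): Query = Query(q.exprs.map {
        case CurrentTs       => current
        case Subquery(inner) => Subquery(rewrite(inner))
        case other           => other
      })
      rewrite(plan)
    }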
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 864ca4f57483d..e92959d9d1ff9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -340,6 +340,28 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
     subqueries ++ subqueries.flatMap(_.subqueriesAll)
   }
 
+  /**
+   * Returns a copy of this node where the given partial function has been recursively applied
+   * first to this node, then this node's subqueries and finally this node's children.
+   * When the partial function does not apply to a given node, it is left unchanged.
+   */
+  def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
+    val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] {
+      override def isDefinedAt(x: PlanType): Boolean = true
+
+      override def apply(plan: PlanType): PlanType = {
+        val transformed = f.applyOrElse[PlanType, PlanType](plan, identity)
+        transformed transformExpressionsDown {
+          case planExpression: PlanExpression[PlanType] =>
+            val newPlan = planExpression.plan.transformDownWithSubqueries(f)
+            planExpression.withNewPlan(newPlan)
+        }
+      }
+    }
+
+    transformDown(g)
+  }
+
   /**
    * A variant of `collect`. This method not only apply the given function to all elements in this
    * plan, also considering all the plans in its (nested) subqueries
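Note: a detail that is easy to miss in transformDownWithSubqueries is the
wrapper g: its isDefinedAt is always true, so transformDown applies it at
every node and the subquery recursion runs even where f itself does not
match. A self-contained illustration of that applyOrElse pattern in plain
Scala (the Node tree below is a hypothetical stand-in for a query plan, not
a Spark type):

    // A tiny rose tree standing in for a plan.
    case class Node(name: String, children: Seq[Node] = Nil)

    // f only matches some nodes ...
    val f: PartialFunction[Node, Node] = {
      case n if n.name == "now" => n.copy(name = "literal")
    }

    // ... but the wrapper is total: unmatched nodes pass through via
    // identity, and the traversal still descends into their children.
    val g: PartialFunction[Node, Node] = new PartialFunction[Node, Node] {
      override def isDefinedAt(x: Node): Boolean = true
      override def apply(n: Node): Node = {
        val transformed = f.applyOrElse(n, identity[Node])
        transformed.copy(children = transformed.children.map(apply))
      }
    }

    // g(Node("filter", Seq(Node("now")))) == Node("filter", Seq(Node("literal")))

Had f been passed to transformDown directly, nodes where f is undefined
would never get the expression rewrite, so subqueries hanging off those
nodes would be skipped.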
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
index 82d6757407b51..f175e9032eb64 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import java.time.ZoneId
 
+import scala.concurrent.duration._
+
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, InSubquery, ListQuery, Literal, Now}
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
@@ -41,11 +43,7 @@
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
 
     val max = (System.currentTimeMillis() + 1) * 1000
-    val lits = new scala.collection.mutable.ArrayBuffer[Long]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Long]
-      e
-    }
+    val lits = literals[Long](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
@@ -59,11 +57,7 @@
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
 
     val max = DateTimeUtils.currentDate(ZoneId.systemDefault())
-    val lits = new scala.collection.mutable.ArrayBuffer[Int]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Int]
-      e
-    }
+    val lits = literals[Int](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
@@ -73,12 +67,49 @@
   test("SPARK-33469: Add current_timezone function") {
     val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation())
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
-    val lits = new scala.collection.mutable.ArrayBuffer[String]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[UTF8String].toString
-      e
-    }
+    val lits = literals[UTF8String](plan)
     assert(lits.size == 1)
-    assert(lits.head == SQLConf.get.sessionLocalTimeZone)
+    assert(lits.head == UTF8String.fromString(SQLConf.get.sessionLocalTimeZone))
   }
+
+  test("analyzer should use equal timestamps across subqueries") {
+    val timestampInSubQuery = Project(Seq(Alias(Now(), "timestamp1")()), LocalRelation())
+    val listSubQuery = ListQuery(timestampInSubQuery)
+    val valueSearchedInSubQuery = Seq(Alias(Now(), "timestamp2")())
+    val inFilterWithSubQuery = InSubquery(valueSearchedInSubQuery, listSubQuery)
+    val input = Project(Nil, Filter(inFilterWithSubQuery, LocalRelation()))
+
+    val plan = Optimize.execute(input.analyze).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size == 3) // transformDownWithSubqueries covers the inner timestamp twice
+    assert(lits.toSet.size == 1)
+  }
+
+  test("analyzer should use consistent timestamps for different timestamp functions") {
+    val differentTimestamps = Seq(
+      Alias(CurrentTimestamp(), "currentTimestamp")(),
+      Alias(Now(), "now")()
+    )
+    val input = Project(differentTimestamps, LocalRelation())
+
+    val plan = Optimize.execute(input).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size === differentTimestamps.size)
+    // there are timezones with a 30 or 45 minute offset, so compare the literals modulo a quarter hour
+    val offsetsFromQuarterHour = lits.map(_ % Duration(15, MINUTES).toMicros).toSet
+    assert(offsetsFromQuarterHour.size == 1)
+  }
+
+  private def literals[T](plan: LogicalPlan): scala.collection.mutable.ArrayBuffer[T] = {
+    val literals = new scala.collection.mutable.ArrayBuffer[T]
+    plan.transformDownWithSubqueries { case subQuery =>
+      subQuery.transformAllExpressions { case expression: Literal =>
+        literals += expression.value.asInstanceOf[T]
+        expression
+      }
+    }
+    literals
+  }
 }
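Note: for reviewers who want to reproduce the original problem outside this
suite, a query of roughly the shape used in the subquery test should show it.
This is a sketch, assuming a local SparkSession; the exact behavior before
the fix depended on timing:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("SPARK-39259")
      .getOrCreate()
    spark.range(1).createOrReplaceTempView("t")

    // Before this patch, now() in the subquery could be replaced with a
    // different timestamp literal than now() in the outer filter, so the
    // IN predicate could spuriously evaluate to false and drop the row.
    spark.sql(
      """SELECT now() AS ts
        |FROM t
        |WHERE now() IN (SELECT now() FROM t)""".stripMargin
    ).show(truncate = false)

With the patch applied, every now() occurrence collapses to the same
literal, so the query reliably returns one row.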