
Commit 5c03e07
Author: Marcelo Vanzin

Cleanup AdjustTimestamps.

Running the rule during resolution also makes it possible to do all the needed adjustments with a single rule, instead of needing a Hive-specific rule for InsertIntoHiveTable.

1 parent: 1eaa045
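For context, a rough end-to-end sketch (not part of this commit) of the behavior the rule implements. The table property key is DateTimeUtils.TIMEZONE_PROPERTY in the code below; its literal value is assumed to be 'timezone' here purely for illustration, as are the table and object names.

import org.apache.spark.sql.SparkSession

// Hypothetical walkthrough: a table with a configured time zone, read under a
// different session time zone. The 'timezone' key and names are assumptions.
object AdjustTimestampsExample extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  spark.conf.set("spark.sql.session.timeZone", "UTC")

  spark.sql(
    """CREATE TABLE events (ts TIMESTAMP)
      |USING parquet
      |TBLPROPERTIES ('timezone' = 'America/Los_Angeles')""".stripMargin)

  // During resolution, AdjustTimestamps wraps the scan in a Project that
  // corrects `ts` from the table zone (LA) to the session zone (UTC);
  // inserts into `events` get the inverse correction.
  spark.table("events").explain(true)
}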

File tree: 4 files changed, +86 -184 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AdjustTimestamps.scala

Lines changed: 83 additions & 153 deletions
@@ -16,183 +16,113 @@
  */
 package org.apache.spark.sql.execution.datasources
 
-import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.{AnalysisException}
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{StringType, TimestampType}
 
-abstract class BaseAdjustTimestampsRule(sparkSession: SparkSession) extends Rule[LogicalPlan] {
+/**
+ * Apply a correction to data loaded from, or saved to, tables that have a configured time zone, so
+ * that timestamps can be read like TIMESTAMP WITHOUT TIMEZONE. This gives correct behavior if you
+ * process data with machines in different timezones, or if you access the data from multiple SQL
+ * engines.
+ */
+case class AdjustTimestamps(conf: SQLConf) extends Rule[LogicalPlan] {
 
-  /**
-   * Apply the correction to all timestamp inputs, and replace all references to the raw attributes
-   * with the new converted inputs.
-   * @return The converted plan, and the replacements to be applied further up the plan
-   */
-  protected def convertInputs(
-      plan: LogicalPlan
-    ): (LogicalPlan, Map[ExprId, NamedExpression]) = plan match {
-    case alreadyConverted@Project(exprs, _) if hasCorrection(exprs) =>
-      (alreadyConverted, Map())
+  def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    case insert: InsertIntoHadoopFsRelationCommand =>
+      val adjusted = adjustTimestampsForWrite(insert.query, insert.catalogTable, insert.options)
+      insert.copy(query = adjusted)
 
-    case lr @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) =>
-      val tzOpt = extractTableTz(lr.catalogTable, fsRelation.options)
-      tzOpt.flatMap { tableTz =>
-        // the table has a timezone set, so after reading the data, apply a conversion
-
-        // SessionTZ (instead of JVM TZ) will make the time display correctly in SQL queries, but
-        // incorrectly if you pull Timestamp objects out (eg. with a dataset.collect())
-        val toTz = sparkSession.sessionState.conf.sessionLocalTimeZone
-        if (toTz != tableTz) {
-          logDebug(s"table tz = $tableTz; converting to current session tz = $toTz")
-          // find timestamp columns, and convert their tz
-          convertTzForAllTimestamps(lr, tableTz, toTz).map { case (fields, replacements) =>
-            (new Project(fields, lr), replacements)
-          }
-        } else {
-          None
-        }
-      }.getOrElse((lr, Map()))
-
-    case relation @ HiveTableRelation(table, _, _) =>
-      val tzOpt = extractTableTz(Some(table), Map())
-      tzOpt.flatMap { tz =>
-        val toTz = sparkSession.sessionState.conf.sessionLocalTimeZone
-        if (toTz != tz) {
-          logDebug(s"table tz = $tz; converting to current session tz = $toTz")
-          // find timestamp columns, and convert their tz
-          convertTzForAllTimestamps(relation, tz, toTz).map { case (fields, replacements) =>
-            (new Project(fields, relation), replacements)
-          }
-        } else {
-          None
-        }
-      }.getOrElse((relation, Map()))
+    case insert @ InsertIntoTable(table: HiveTableRelation, _, query, _, _) =>
+      val adjusted = adjustTimestampsForWrite(insert.query, Some(table.tableMeta), Map())
+      insert.copy(query = adjusted)
 
     case other =>
-      // first, process all the children -- this ensures we have the right renames in scope.
-      var newReplacements = Map[ExprId, NamedExpression]()
-      val fixedPlan = other.mapChildren { originalPlan =>
-        val (newPlan, extraReplacements) = convertInputs(originalPlan)
-        newReplacements ++= extraReplacements
-        newPlan
-      }
-      // now we need to adjust all names to use the new version.
-      val fixedExpressions = fixedPlan.mapExpressions { outerExp =>
-        val adjustedExp = outerExp.transformUp { case exp: NamedExpression =>
-          try {
-            newReplacements.get(exp.exprId).getOrElse(exp)
-          } catch {
-            // UnresolvedAttributes etc. will cause problems later anyway, we just dont' want to
-            // expose the error here
-            case ue: UnresolvedException[_] => exp
-          }
-        }
-        adjustedExp
-      }
-      (fixedExpressions, newReplacements)
+      convertInputs(plan)
   }
 
-  protected def hasCorrection(exprs: Seq[Expression]): Boolean = {
-    exprs.exists { expr =>
-      expr.isInstanceOf[TimestampTimezoneCorrection] || hasCorrection(expr.children)
-    }
+  private def convertInputs(plan: LogicalPlan): LogicalPlan = plan match {
+    case adjusted @ Project(exprs, _) if hasCorrection(exprs) =>
+      adjusted
+
+    case lr @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) =>
+      adjustTimestamps(lr, lr.catalogTable, fsRelation.options, true)
+
+    case hr @ HiveTableRelation(table, _, _) =>
+      adjustTimestamps(hr, Some(table), Map(), true)
+
+    case other =>
+      other.mapChildren { originalPlan =>
+        convertInputs(originalPlan)
+      }
   }
 
-  protected def writeConversion(
+  private def adjustTimestamps(
+      plan: LogicalPlan,
       table: Option[CatalogTable],
       options: Map[String, String],
-      query: LogicalPlan): LogicalPlan = {
-    val tableTz = extractTableTz(table, options)
-    val internalTz = sparkSession.sessionState.conf.sessionLocalTimeZone
-    if (tableTz.isDefined && tableTz != internalTz) {
-      convertTzForAllTimestamps(query, internalTz, tableTz.get).map { case (fields, _) =>
-        new Project(fields, query)
-      }.getOrElse(query)
-    } else {
-      query
-    }
-  }
-
-  protected def extractTableTz(options: Map[String, String]): Option[String] = {
-    options.get(DateTimeUtils.TIMEZONE_PROPERTY)
+      reading: Boolean): LogicalPlan = {
+    val tableTz = table.flatMap(_.properties.get(DateTimeUtils.TIMEZONE_PROPERTY))
+      .orElse(options.get(DateTimeUtils.TIMEZONE_PROPERTY))
+
+    tableTz.map { tz =>
+      val sessionTz = conf.sessionLocalTimeZone
+      val toTz = if (reading) sessionTz else tz
+      val fromTz = if (reading) tz else sessionTz
+      logDebug(
+        s"table tz = $tz; converting ${if (reading) "to" else "from"} session tz = $sessionTz\n")
+
+      var hasTimestamp = false
+      val adjusted = plan.expressions.map {
+        case e: NamedExpression if e.dataType == TimestampType =>
+          val adjustment = TimestampTimezoneCorrection(e.toAttribute,
+            Literal.create(fromTz, StringType), Literal.create(toTz, StringType))
+          hasTimestamp = true
+          Alias(adjustment, e.name)()
+
+        case other: NamedExpression =>
+          other
+
+        case unnamed =>
+          throw new AnalysisException(s"Unexpected expr: $unnamed")
+      }.toList
+
+      if (hasTimestamp) Project(adjusted, plan) else plan
+    }.getOrElse(plan)
   }
 
-  protected def extractTableTz(
+  private def adjustTimestampsForWrite(
+      query: LogicalPlan,
       table: Option[CatalogTable],
-      options: Map[String, String]): Option[String] = {
-    table.flatMap { tbl => extractTableTz(tbl.properties) }.orElse(extractTableTz(options))
+      options: Map[String, String]): LogicalPlan = query match {
+    case unadjusted if !hasOutputCorrection(unadjusted.expressions) =>
+      // The query might be reading from a table with a configured time zone; this makes sure we
+      // apply the correct conversions for that data.
+      val fixedInputs = convertInputs(unadjusted)
+      adjustTimestamps(fixedInputs, table, options, false)
+
+    case _ =>
+      query
   }
 
-  /**
-   * Find all timestamp fields in the given relation. For each one, replace it with an expression
-   * that converts the timezone of the timestamp, and assigns an alias to that new expression.
-   * (Leave non-timestamp fields alone.) Also return a map from the original id for the timestamp
-   * field, to the new alias of the timezone-corrected expression.
-   */
-  protected def convertTzForAllTimestamps(
-      relation: LogicalPlan,
-      fromTz: String,
-      toTz: String): Option[(Seq[NamedExpression], Map[ExprId, NamedExpression])] = {
-    val schema = relation.schema
-    var foundTs = false
-    var replacements = Map[ExprId, NamedExpression]()
-    val modifiedFields: Seq[NamedExpression] = schema.map { field =>
-      val exp = relation.resolve(Seq(field.name), sparkSession.sessionState.conf.resolver)
-        .getOrElse {
-          val inputColumns = schema.map(_.name).mkString(", ")
-          throw new AnalysisException(
-            s"cannot resolve '${field.name}' given input columns: [$inputColumns]")
-        }
-      if (field.dataType == TimestampType) {
-        foundTs = true
-        val adjustedTs = Alias(
-          TimestampTimezoneCorrection(
-            exp,
-            Literal.create(fromTz, StringType),
-            Literal.create(toTz, StringType)
-          ),
-          field.name
-        )()
-        // we also need to rename all occurrences of this field further up in the plan
-        // to refer to our new adjusted timestamp, so we pass this replacement up the call stack.
-        replacements += exp.exprId -> adjustedTs.toAttribute
-        adjustedTs
-      } else {
-        exp
-      }
+  private def hasCorrection(exprs: Seq[Expression]): Boolean = {
+    exprs.exists { expr =>
+      expr.isInstanceOf[TimestampTimezoneCorrection] || hasCorrection(expr.children)
     }
-    if (foundTs) Some((modifiedFields, replacements)) else None
   }
-}
 
-/**
- * Apply a correction to data loaded from, or saved to, tables that have a configured time zone, so
- * that timestamps can be read like TIMESTAMP WITHOUT TIMEZONE. This gives correct behavior if you
- * process data with machines in different timezones, or if you access the data from multiple SQL
- * engines.
- */
-case class AdjustTimestamps(sparkSession: SparkSession)
-  extends BaseAdjustTimestampsRule(sparkSession) {
-
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    // we can't use transformUp because we want to terminate recursion if there was already
-    // timestamp correction, to keep this idempotent.
-    plan match {
-      case insert: InsertIntoHadoopFsRelationCommand =>
-        // The query might be reading from a parquet table which requires a different conversion;
-        // this makes sure we apply the correct conversions there.
-        val (fixedQuery, _) = convertInputs(insert.query)
-        val fixedOutput = writeConversion(insert.catalogTable, insert.options, fixedQuery)
-        insert.copy(query = fixedOutput)
-
-      case other =>
-        // recurse into children to see if we're reading data that needs conversion
-        val (convertedPlan, _) = convertInputs(plan)
-        convertedPlan
+  private def hasOutputCorrection(exprs: Seq[Expression]): Boolean = {
+    // Output correction is any TimestampTimezoneCorrection that converts from the current
+    // session's time zone.
+    val sessionTz = conf.sessionLocalTimeZone
+    exprs.exists {
+      case TimestampTimezoneCorrection(_, from, _) => from.toString() == sessionTz
+      case other => hasOutputCorrection(other.children)
     }
   }
 }

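TimestampTimezoneCorrection itself is not defined in this commit. As a minimal sketch of the conversion it presumably performs on Catalyst's microsecond timestamp values, assuming it delegates to DateTimeUtils.convertTz like other time zone handling in Spark (the object and method names here are illustrative):

import java.util.TimeZone
import org.apache.spark.sql.catalyst.util.DateTimeUtils

object TimestampCorrectionSketch {
  // Sketch only: shift a microsecond timestamp so its wall-clock reading in
  // toTz matches its previous wall-clock reading in fromTz.
  def correctTimestamp(micros: Long, fromTz: String, toTz: String): Long = {
    DateTimeUtils.convertTz(
      micros,
      TimeZone.getTimeZone(fromTz),
      TimeZone.getTimeZone(toTz))
  }
}

Note that the rule only rewrites the plan (aliased projections over timestamp columns); the per-value arithmetic happens at execution time inside the expression, and the hasCorrection/hasOutputCorrection guards keep repeated analyzer passes from stacking corrections.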
sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala

Lines changed: 1 addition & 1 deletion
@@ -158,13 +158,13 @@ abstract class BaseSessionStateBuilder(
     override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
       new FindDataSourceTable(session) +:
       new ResolveSQLOnFile(session) +:
+      AdjustTimestamps(conf) +:
       customResolutionRules
 
     override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
       PreprocessTableCreation(session) +:
       PreprocessTableInsertion(conf) +:
       DataSourceAnalysis(conf) +:
-      AdjustTimestamps(session) +:
       customPostHocResolutionRules
 
     override val extendedCheckRules: Seq[LogicalPlan => Unit] =

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala

Lines changed: 1 addition & 2 deletions
@@ -71,6 +71,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
       new ResolveHiveSerdeTable(session) +:
       new FindDataSourceTable(session) +:
       new ResolveSQLOnFile(session) +:
+      AdjustTimestamps(conf) +:
       customResolutionRules
 
     override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
@@ -80,8 +81,6 @@
       PreprocessTableInsertion(conf) +:
       DataSourceAnalysis(conf) +:
       HiveAnalysis +:
-      HiveAdjustTimestamps(session) +:
-      AdjustTimestamps(session) +:
       customPostHocResolutionRules
 
     override val extendedCheckRules: Seq[LogicalPlan => Unit] =
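With this change both session builders register the same AdjustTimestamps(conf) instance during resolution rather than post-hoc. For comparison, third-party code can hook a rule into the same resolution phase through the extensions API; a sketch under the assumption that the SparkSessionExtensions API (Spark 2.2+) is available, with MyResolutionRule as a hypothetical placeholder:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical no-op rule standing in for anything that must run during
// resolution alongside rules like AdjustTimestamps.
case class MyResolutionRule(session: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

object ExtensionsExample extends App {
  val spark = SparkSession.builder()
    .master("local[1]")
    .withExtensions(_.injectResolutionRule(MyResolutionRule))
    .getOrCreate()
}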

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala

Lines changed: 1 addition & 28 deletions
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTab
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
-import org.apache.spark.sql.execution.datasources.{BaseAdjustTimestampsRule, CreateTable, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
 import org.apache.spark.sql.hive.execution._
 import org.apache.spark.sql.hive.orc.OrcFileFormat
@@ -217,33 +217,6 @@ case class RelationConversions(
   }
 }
 
-/**
- * Apply a correction to data loaded from, or saved to, tables that have a configured time zone, so
- * that timestamps can be read like TIMESTAMP WITHOUT TIMEZONE. This gives correct behavior if you
- * process data with machines in different timezones, or if you access the data from multiple SQL
- * engines.
- */
-case class HiveAdjustTimestamps(sparkSession: SparkSession)
-  extends BaseAdjustTimestampsRule(sparkSession) {
-
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    // we can't use transformUp because we want to terminate recursion if there was already
-    // timestamp correction, to keep this idempotent.
-    plan match {
-      case insert: InsertIntoHiveTable =>
-        // The query might be reading from a parquet table which requires a different conversion;
-        // this makes sure we apply the correct conversions there.
-        val (fixedQuery, _) = convertInputs(insert.query)
-        val fixedOutput = writeConversion(Some(insert.table), Map(), fixedQuery)
-        insert.copy(query = fixedOutput)
-
-      case other =>
-        plan
-    }
-  }
-
-}
-
 private[hive] trait HiveStrategies {
   // Possibly being too clever with types here... or not clever enough.
   self: SparkPlanner =>
