|
16 | 16 | */ |
17 | 17 | package org.apache.spark.sql.execution.datasources |
18 | 18 |
|
19 | | -import org.apache.spark.sql.{AnalysisException, SparkSession} |
20 | | -import org.apache.spark.sql.catalyst.analysis.UnresolvedException |
| 19 | +import org.apache.spark.sql.{AnalysisException} |
21 | 20 | import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation} |
22 | 21 | import org.apache.spark.sql.catalyst.expressions._ |
23 | | -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} |
| 22 | +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project} |
24 | 23 | import org.apache.spark.sql.catalyst.rules.Rule |
25 | 24 | import org.apache.spark.sql.catalyst.util.DateTimeUtils |
| 25 | +import org.apache.spark.sql.internal.SQLConf |
26 | 26 | import org.apache.spark.sql.types.{StringType, TimestampType} |
27 | 27 |
|
28 | | -abstract class BaseAdjustTimestampsRule(sparkSession: SparkSession) extends Rule[LogicalPlan] { |
| 28 | +/** |
| 29 | + * Apply a correction to data loaded from, or saved to, tables that have a configured time zone, so |
| 30 | + * that timestamps can be read like TIMESTAMP WITHOUT TIMEZONE. This gives correct behavior if you |
| 31 | + * process data with machines in different timezones, or if you access the data from multiple SQL |
| 32 | + * engines. |
| 33 | + */ |
| 34 | +case class AdjustTimestamps(conf: SQLConf) extends Rule[LogicalPlan] { |
29 | 35 |
|
30 | | - /** |
31 | | - * Apply the correction to all timestamp inputs, and replace all references to the raw attributes |
32 | | - * with the new converted inputs. |
33 | | - * @return The converted plan, and the replacements to be applied further up the plan |
34 | | - */ |
35 | | - protected def convertInputs( |
36 | | - plan: LogicalPlan |
37 | | - ): (LogicalPlan, Map[ExprId, NamedExpression]) = plan match { |
38 | | - case alreadyConverted@Project(exprs, _) if hasCorrection(exprs) => |
39 | | - (alreadyConverted, Map()) |
| 36 | + def apply(plan: LogicalPlan): LogicalPlan = plan match { |
| 37 | + case insert: InsertIntoHadoopFsRelationCommand => |
| 38 | + val adjusted = adjustTimestampsForWrite(insert.query, insert.catalogTable, insert.options) |
| 39 | + insert.copy(query = adjusted) |
40 | 40 |
|
41 | | - case lr @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) => |
42 | | - val tzOpt = extractTableTz(lr.catalogTable, fsRelation.options) |
43 | | - tzOpt.flatMap { tableTz => |
44 | | - // the table has a timezone set, so after reading the data, apply a conversion |
45 | | - |
46 | | - // SessionTZ (instead of JVM TZ) will make the time display correctly in SQL queries, but |
47 | | - // incorrectly if you pull Timestamp objects out (eg. with a dataset.collect()) |
48 | | - val toTz = sparkSession.sessionState.conf.sessionLocalTimeZone |
49 | | - if (toTz != tableTz) { |
50 | | - logDebug(s"table tz = $tableTz; converting to current session tz = $toTz") |
51 | | - // find timestamp columns, and convert their tz |
52 | | - convertTzForAllTimestamps(lr, tableTz, toTz).map { case (fields, replacements) => |
53 | | - (new Project(fields, lr), replacements) |
54 | | - } |
55 | | - } else { |
56 | | - None |
57 | | - } |
58 | | - }.getOrElse((lr, Map())) |
59 | | - |
60 | | - case relation @ HiveTableRelation(table, _, _) => |
61 | | - val tzOpt = extractTableTz(Some(table), Map()) |
62 | | - tzOpt.flatMap { tz => |
63 | | - val toTz = sparkSession.sessionState.conf.sessionLocalTimeZone |
64 | | - if (toTz != tz) { |
65 | | - logDebug(s"table tz = $tz; converting to current session tz = $toTz") |
66 | | - // find timestamp columns, and convert their tz |
67 | | - convertTzForAllTimestamps(relation, tz, toTz).map { case (fields, replacements) => |
68 | | - (new Project(fields, relation), replacements) |
69 | | - } |
70 | | - } else { |
71 | | - None |
72 | | - } |
73 | | - }.getOrElse((relation, Map())) |
| 41 | + case insert @ InsertIntoTable(table: HiveTableRelation, _, query, _, _) => |
| 42 | + val adjusted = adjustTimestampsForWrite(insert.query, Some(table.tableMeta), Map()) |
| 43 | + insert.copy(query = adjusted) |
74 | 44 |
|
75 | 45 | case other => |
76 | | - // first, process all the children -- this ensures we have the right renames in scope. |
77 | | - var newReplacements = Map[ExprId, NamedExpression]() |
78 | | - val fixedPlan = other.mapChildren { originalPlan => |
79 | | - val (newPlan, extraReplacements) = convertInputs(originalPlan) |
80 | | - newReplacements ++= extraReplacements |
81 | | - newPlan |
82 | | - } |
83 | | - // now we need to adjust all names to use the new version. |
84 | | - val fixedExpressions = fixedPlan.mapExpressions { outerExp => |
85 | | - val adjustedExp = outerExp.transformUp { case exp: NamedExpression => |
86 | | - try { |
87 | | - newReplacements.get(exp.exprId).getOrElse(exp) |
88 | | - } catch { |
89 | | - // UnresolvedAttributes etc. will cause problems later anyway, we just don't want to |
90 | | - // expose the error here |
91 | | - case ue: UnresolvedException[_] => exp |
92 | | - } |
93 | | - } |
94 | | - adjustedExp |
95 | | - } |
96 | | - (fixedExpressions, newReplacements) |
| 46 | + convertInputs(plan) |
97 | 47 | } |
98 | 48 |
|
99 | | - protected def hasCorrection(exprs: Seq[Expression]): Boolean = { |
100 | | - exprs.exists { expr => |
101 | | - expr.isInstanceOf[TimestampTimezoneCorrection] || hasCorrection(expr.children) |
102 | | - } |
| 49 | + private def convertInputs(plan: LogicalPlan): LogicalPlan = plan match { |
| 50 | + case adjusted @ Project(exprs, _) if hasCorrection(exprs) => |
| 51 | + adjusted |
| 52 | + |
| 53 | + case lr @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) => |
| 54 | + adjustTimestamps(lr, lr.catalogTable, fsRelation.options, true) |
| 55 | + |
| 56 | + case hr @ HiveTableRelation(table, _, _) => |
| 57 | + adjustTimestamps(hr, Some(table), Map(), true) |
| 58 | + |
| 59 | + case other => |
| 60 | + other.mapChildren { originalPlan => |
| 61 | + convertInputs(originalPlan) |
| 62 | + } |
103 | 63 | } |
104 | 64 |
|
105 | | - protected def writeConversion( |
| 65 | + private def adjustTimestamps( |
| 66 | + plan: LogicalPlan, |
106 | 67 | table: Option[CatalogTable], |
107 | 68 | options: Map[String, String], |
108 | | - query: LogicalPlan): LogicalPlan = { |
109 | | - val tableTz = extractTableTz(table, options) |
110 | | - val internalTz = sparkSession.sessionState.conf.sessionLocalTimeZone |
111 | | - if (tableTz.isDefined && tableTz != internalTz) { |
112 | | - convertTzForAllTimestamps(query, internalTz, tableTz.get).map { case (fields, _) => |
113 | | - new Project(fields, query) |
114 | | - }.getOrElse(query) |
115 | | - } else { |
116 | | - query |
117 | | - } |
118 | | - } |
119 | | - |
120 | | - protected def extractTableTz(options: Map[String, String]): Option[String] = { |
121 | | - options.get(DateTimeUtils.TIMEZONE_PROPERTY) |
| 69 | + reading: Boolean): LogicalPlan = { |
| 70 | + val tableTz = table.flatMap(_.properties.get(DateTimeUtils.TIMEZONE_PROPERTY)) |
| 71 | + .orElse(options.get(DateTimeUtils.TIMEZONE_PROPERTY)) |
| 72 | + |
| 73 | + tableTz.map { tz => |
| 74 | + val sessionTz = conf.sessionLocalTimeZone |
| 75 | + val toTz = if (reading) sessionTz else tz |
| 76 | + val fromTz = if (reading) tz else sessionTz |
| 77 | + logDebug( |
| 78 | + s"table tz = $tz; converting ${if (reading) "to" else "from"} session tz = $sessionTz\n") |
| 79 | + |
| 80 | + var hasTimestamp = false |
| 81 | + val adjusted = plan.expressions.map { |
| 82 | + case e: NamedExpression if e.dataType == TimestampType => |
| 83 | + val adjustment = TimestampTimezoneCorrection(e.toAttribute, |
| 84 | + Literal.create(fromTz, StringType), Literal.create(toTz, StringType)) |
| 85 | + hasTimestamp = true |
| 86 | + Alias(adjustment, e.name)() |
| 87 | + |
| 88 | + case other: NamedExpression => |
| 89 | + other |
| 90 | + |
| 91 | + case unnamed => |
| 92 | + throw new AnalysisException(s"Unexpected expr: $unnamed") |
| 93 | + }.toList |
| 94 | + |
| 95 | + if (hasTimestamp) Project(adjusted, plan) else plan |
| 96 | + }.getOrElse(plan) |
122 | 97 | } |
123 | 98 |
|
124 | | - protected def extractTableTz( |
| 99 | + private def adjustTimestampsForWrite( |
| 100 | + query: LogicalPlan, |
125 | 101 | table: Option[CatalogTable], |
126 | | - options: Map[String, String]): Option[String] = { |
127 | | - table.flatMap { tbl => extractTableTz(tbl.properties) }.orElse(extractTableTz(options)) |
| 102 | + options: Map[String, String]): LogicalPlan = query match { |
| 103 | + case unadjusted if !hasOutputCorrection(unadjusted.expressions) => |
| 104 | + // The query might be reading from a table with a configured time zone; this makes sure we |
| 105 | + // apply the correct conversions for that data. |
| 106 | + val fixedInputs = convertInputs(unadjusted) |
| 107 | + adjustTimestamps(fixedInputs, table, options, false) |
| 108 | + |
| 109 | + case _ => |
| 110 | + query |
128 | 111 | } |
129 | 112 |
|
130 | | - /** |
131 | | - * Find all timestamp fields in the given relation. For each one, replace it with an expression |
132 | | - * that converts the timezone of the timestamp, and assigns an alias to that new expression. |
133 | | - * (Leave non-timestamp fields alone.) Also return a map from the original id for the timestamp |
134 | | - * field, to the new alias of the timezone-corrected expression. |
135 | | - */ |
136 | | - protected def convertTzForAllTimestamps( |
137 | | - relation: LogicalPlan, |
138 | | - fromTz: String, |
139 | | - toTz: String): Option[(Seq[NamedExpression], Map[ExprId, NamedExpression])] = { |
140 | | - val schema = relation.schema |
141 | | - var foundTs = false |
142 | | - var replacements = Map[ExprId, NamedExpression]() |
143 | | - val modifiedFields: Seq[NamedExpression] = schema.map { field => |
144 | | - val exp = relation.resolve(Seq(field.name), sparkSession.sessionState.conf.resolver) |
145 | | - .getOrElse { |
146 | | - val inputColumns = schema.map(_.name).mkString(", ") |
147 | | - throw new AnalysisException( |
148 | | - s"cannot resolve '${field.name}' given input columns: [$inputColumns]") |
149 | | - } |
150 | | - if (field.dataType == TimestampType) { |
151 | | - foundTs = true |
152 | | - val adjustedTs = Alias( |
153 | | - TimestampTimezoneCorrection( |
154 | | - exp, |
155 | | - Literal.create(fromTz, StringType), |
156 | | - Literal.create(toTz, StringType) |
157 | | - ), |
158 | | - field.name |
159 | | - )() |
160 | | - // we also need to rename all occurrences of this field further up in the plan |
161 | | - // to refer to our new adjusted timestamp, so we pass this replacement up the call stack. |
162 | | - replacements += exp.exprId -> adjustedTs.toAttribute |
163 | | - adjustedTs |
164 | | - } else { |
165 | | - exp |
166 | | - } |
| 113 | + private def hasCorrection(exprs: Seq[Expression]): Boolean = { |
| 114 | + exprs.exists { expr => |
| 115 | + expr.isInstanceOf[TimestampTimezoneCorrection] || hasCorrection(expr.children) |
167 | 116 | } |
168 | | - if (foundTs) Some((modifiedFields, replacements)) else None |
169 | 117 | } |
170 | | -} |
171 | 118 |
|
172 | | -/** |
173 | | - * Apply a correction to data loaded from, or saved to, tables that have a configured time zone, so |
174 | | - * that timestamps can be read like TIMESTAMP WITHOUT TIMEZONE. This gives correct behavior if you |
175 | | - * process data with machines in different timezones, or if you access the data from multiple SQL |
176 | | - * engines. |
177 | | - */ |
178 | | -case class AdjustTimestamps(sparkSession: SparkSession) |
179 | | - extends BaseAdjustTimestampsRule(sparkSession) { |
180 | | - |
181 | | - def apply(plan: LogicalPlan): LogicalPlan = { |
182 | | - // we can't use transformUp because we want to terminate recursion if there was already |
183 | | - // timestamp correction, to keep this idempotent. |
184 | | - plan match { |
185 | | - case insert: InsertIntoHadoopFsRelationCommand => |
186 | | - // The query might be reading from a parquet table which requires a different conversion; |
187 | | - // this makes sure we apply the correct conversions there. |
188 | | - val (fixedQuery, _) = convertInputs(insert.query) |
189 | | - val fixedOutput = writeConversion(insert.catalogTable, insert.options, fixedQuery) |
190 | | - insert.copy(query = fixedOutput) |
191 | | - |
192 | | - case other => |
193 | | - // recurse into children to see if we're reading data that needs conversion |
194 | | - val (convertedPlan, _) = convertInputs(plan) |
195 | | - convertedPlan |
| 119 | + private def hasOutputCorrection(exprs: Seq[Expression]): Boolean = { |
| 120 | + // Output correction is any TimestampTimezoneCorrection that converts from the current |
| 121 | + // session's time zone. |
| 122 | + val sessionTz = conf.sessionLocalTimeZone |
| 123 | + exprs.exists { |
| 124 | + case TimestampTimezoneCorrection(_, from, _) => from.toString() == sessionTz |
| 125 | + case other => hasOutputCorrection(other.children) |
196 | 126 | } |
197 | 127 | } |
198 | 128 |
|
|
0 commit comments