@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
2020import javax .annotation .Nullable
2121
2222import org .apache .spark .sql .catalyst .expressions ._
23- import org .apache .spark .sql .catalyst .plans .logical .{ LogicalPlan , Project , Union }
23+ import org .apache .spark .sql .catalyst .plans .logical ._
2424import org .apache .spark .sql .catalyst .rules .Rule
2525import org .apache .spark .sql .types ._
2626
@@ -168,52 +168,65 @@ object HiveTypeCoercion {
168168 * - LongType to DoubleType
169169 */
170170 object WidenTypes extends Rule [LogicalPlan ] {
171- def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
172- // TODO: unions with fixed-precision decimals
173- case u @ Union (left, right) if u.childrenResolved && ! u.resolved =>
174- val castedInput = left.output.zip(right.output).map {
175- // When a string is found on one side, make the other side a string too.
176- case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
177- (lhs, Alias (Cast (rhs, StringType ), rhs.name)())
178- case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
179- (Alias (Cast (lhs, StringType ), lhs.name)(), rhs)
180171
181- case (lhs, rhs) if lhs.dataType != rhs.dataType =>
182- logDebug(s " Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}" )
183- findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
184- val newLeft =
185- if (lhs.dataType == widestType) lhs else Alias (Cast (lhs, widestType), lhs.name)()
186- val newRight =
187- if (rhs.dataType == widestType) rhs else Alias (Cast (rhs, widestType), rhs.name)()
188-
189- (newLeft, newRight)
190- }.getOrElse {
191- // If there is no applicable conversion, leave expression unchanged.
192- (lhs, rhs)
193- }
172+ private [this ] def widenOutputTypes (planName : String , left : LogicalPlan , right : LogicalPlan ):
173+ (LogicalPlan , LogicalPlan ) = {
174+
175+ // TODO: with fixed-precision decimals
176+ val castedInput = left.output.zip(right.output).map {
177+ // When a string is found on one side, make the other side a string too.
178+ case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
179+ (lhs, Alias (Cast (rhs, StringType ), rhs.name)())
180+ case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
181+ (Alias (Cast (lhs, StringType ), lhs.name)(), rhs)
182+
183+ case (lhs, rhs) if lhs.dataType != rhs.dataType =>
184+ logDebug(s " Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}" )
185+ findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
186+ val newLeft =
187+ if (lhs.dataType == widestType) lhs else Alias (Cast (lhs, widestType), lhs.name)()
188+ val newRight =
189+ if (rhs.dataType == widestType) rhs else Alias (Cast (rhs, widestType), rhs.name)()
190+
191+ (newLeft, newRight)
192+ }.getOrElse {
193+ // If there is no applicable conversion, leave expression unchanged.
194+ (lhs, rhs)
195+ }
194196
195- case other => other
196- }
197+ case other => other
198+ }
197199
198- val (castedLeft, castedRight) = castedInput.unzip
200+ val (castedLeft, castedRight) = castedInput.unzip
199201
200- val newLeft =
201- if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
202- logDebug(s " Widening numeric types in union $castedLeft ${left.output}" )
203- Project (castedLeft, left)
204- } else {
205- left
206- }
202+ val newLeft =
203+ if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
204+ logDebug(s " Widening numeric types in $planName $castedLeft ${left.output}" )
205+ Project (castedLeft, left)
206+ } else {
207+ left
208+ }
207209
208- val newRight =
209- if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
210- logDebug(s " Widening numeric types in union $castedRight ${right.output}" )
211- Project (castedRight, right)
212- } else {
213- right
214- }
210+ val newRight =
211+ if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
212+ logDebug(s " Widening numeric types in $planName $castedRight ${right.output}" )
213+ Project (castedRight, right)
214+ } else {
215+ right
216+ }
217+ (newLeft, newRight)
218+ }
215219
220+ def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
221+ case u @ Union (left, right) if u.childrenResolved && ! u.resolved =>
222+ val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right)
216223 Union (newLeft, newRight)
224+ case e @ Except (left, right) if e.childrenResolved && ! e.resolved =>
225+ val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right)
226+ Except (newLeft, newRight)
227+ case i @ Intersect (left, right) if i.childrenResolved && ! i.resolved =>
228+ val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right)
229+ Intersect (newLeft, newRight)
217230 }
218231 }
219232
0 commit comments