@@ -41,7 +41,7 @@ object HiveTypeCoercion {
4141 * with primitive types, because in that case the precision and scale of the result depends on
4242 * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision ]].
4343 */
44- val findTightestCommonType : (DataType , DataType ) => Option [DataType ] = {
44+ val findTightestCommonTypeOfTwo : (DataType , DataType ) => Option [DataType ] = {
4545 case (t1, t2) if t1 == t2 => Some (t1)
4646 case (NullType , t1) => Some (t1)
4747 case (t1, NullType ) => Some (t1)
@@ -57,6 +57,17 @@ object HiveTypeCoercion {
5757
5858 case _ => None
5959 }
60+
61+ /**
62+ * Find the tightest common type of a set of types by continuously applying
63+ * `findTightestCommonTypeOfTwo` on these types.
64+ */
65+ private def findTightestCommonType (types : Seq [DataType ]) = {
66+ types.foldLeft[Option [DataType ]](Some (NullType ))((r, c) => r match {
67+ case None => None
68+ case Some (d) => findTightestCommonTypeOfTwo(d, c)
69+ })
70+ }
6071}
6172
6273/**
@@ -180,7 +191,7 @@ trait HiveTypeCoercion {
180191
181192 case (l, r) if l.dataType != r.dataType =>
182193 logDebug(s " Resolving mismatched union input ${l.dataType}, ${r.dataType}" )
183- findTightestCommonType (l.dataType, r.dataType).map { widestType =>
194+ findTightestCommonTypeOfTwo (l.dataType, r.dataType).map { widestType =>
184195 val newLeft =
185196 if (l.dataType == widestType) l else Alias (Cast (l, widestType), l.name)()
186197 val newRight =
@@ -217,7 +228,7 @@ trait HiveTypeCoercion {
217228 case e if ! e.childrenResolved => e
218229
219230 case b : BinaryExpression if b.left.dataType != b.right.dataType =>
220- findTightestCommonType (b.left.dataType, b.right.dataType).map { widestType =>
231+ findTightestCommonTypeOfTwo (b.left.dataType, b.right.dataType).map { widestType =>
221232 val newLeft =
222233 if (b.left.dataType == widestType) b.left else Cast (b.left, widestType)
223234 val newRight =
@@ -441,21 +452,18 @@ trait HiveTypeCoercion {
441452 DecimalType (min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
442453 )
443454
444- case LessThan (e1 @ DecimalType .Expression (p1, s1),
445- e2 @ DecimalType .Expression (p2, s2)) if p1 != p2 || s1 != s2 =>
446- LessThan (Cast (e1, DecimalType .Unlimited ), Cast (e2, DecimalType .Unlimited ))
447-
448- case LessThanOrEqual (e1 @ DecimalType .Expression (p1, s1),
449- e2 @ DecimalType .Expression (p2, s2)) if p1 != p2 || s1 != s2 =>
450- LessThanOrEqual (Cast (e1, DecimalType .Unlimited ), Cast (e2, DecimalType .Unlimited ))
451-
452- case GreaterThan (e1 @ DecimalType .Expression (p1, s1),
453- e2 @ DecimalType .Expression (p2, s2)) if p1 != p2 || s1 != s2 =>
454- GreaterThan (Cast (e1, DecimalType .Unlimited ), Cast (e2, DecimalType .Unlimited ))
455-
456- case GreaterThanOrEqual (e1 @ DecimalType .Expression (p1, s1),
457- e2 @ DecimalType .Expression (p2, s2)) if p1 != p2 || s1 != s2 =>
458- GreaterThanOrEqual (Cast (e1, DecimalType .Unlimited ), Cast (e2, DecimalType .Unlimited ))
455+ // When we compare 2 decimal types with different precisions, cast them to the smallest
456+ // common precision.
457+ case b @ BinaryComparison (e1 @ DecimalType .Expression (p1, s1),
458+ e2 @ DecimalType .Expression (p2, s2)) if p1 != p2 || s1 != s2 =>
459+ val resultType = DecimalType (max(p1, p2), max(s1, s2))
460+ b.makeCopy(Array (Cast (e1, resultType), Cast (e2, resultType)))
461+ case b @ BinaryComparison (e1 @ DecimalType .Fixed (_, _), e2)
462+ if e2.dataType == DecimalType .Unlimited =>
463+ b.makeCopy(Array (Cast (e1, DecimalType .Unlimited ), e2))
464+ case b @ BinaryComparison (e1, e2 @ DecimalType .Fixed (_, _))
465+ if e1.dataType == DecimalType .Unlimited =>
466+ b.makeCopy(Array (e1, Cast (e2, DecimalType .Unlimited )))
459467
460468 // Promote integers inside a binary expression with fixed-precision decimals to decimals,
461469 // and fixed-precision decimals in an expression with floats / doubles to doubles
@@ -570,7 +578,7 @@ trait HiveTypeCoercion {
570578
571579 case a @ CreateArray (children) if ! a.resolved =>
572580 val commonType = a.childTypes.reduce(
573- (a, b) => findTightestCommonType (a, b).getOrElse(StringType ))
581+ (a, b) => findTightestCommonTypeOfTwo (a, b).getOrElse(StringType ))
574582 CreateArray (
575583 children.map(c => if (c.dataType == commonType) c else Cast (c, commonType)))
576584
@@ -599,14 +607,9 @@ trait HiveTypeCoercion {
599607 // from the list. So we need to make sure the return type is deterministic and
600608 // compatible with every child column.
601609 case Coalesce (es) if es.map(_.dataType).distinct.size > 1 =>
602- val dt : Option [DataType ] = Some (NullType )
603610 val types = es.map(_.dataType)
604- val rt = types.foldLeft(dt)((r, c) => r match {
605- case None => None
606- case Some (d) => findTightestCommonType(d, c)
607- })
608- rt match {
609- case Some (finaldt) => Coalesce (es.map(Cast (_, finaldt)))
611+ findTightestCommonType(types) match {
612+ case Some (finalDataType) => Coalesce (es.map(Cast (_, finalDataType)))
610613 case None =>
611614 sys.error(s " Could not determine return type of Coalesce for ${types.mkString(" ," )}" )
612615 }
@@ -619,17 +622,13 @@ trait HiveTypeCoercion {
619622 */
620623 object Division extends Rule [LogicalPlan ] {
621624 def apply (plan : LogicalPlan ): LogicalPlan = plan transformAllExpressions {
622- // Skip nodes who's children have not been resolved yet.
623- case e if ! e.childrenResolved => e
625+ // Skip nodes who has not been resolved yet,
626+ // as this is an extra rule which should be applied at last.
627+ case e if ! e.resolved => e
624628
625629 // Decimal and Double remain the same
626- case d : Divide if d.resolved && d.dataType == DoubleType => d
627- case d : Divide if d.resolved && d.dataType.isInstanceOf [DecimalType ] => d
628-
629- case Divide (l, r) if l.dataType.isInstanceOf [DecimalType ] =>
630- Divide (l, Cast (r, DecimalType .Unlimited ))
631- case Divide (l, r) if r.dataType.isInstanceOf [DecimalType ] =>
632- Divide (Cast (l, DecimalType .Unlimited ), r)
630+ case d : Divide if d.dataType == DoubleType => d
631+ case d : Divide if d.dataType.isInstanceOf [DecimalType ] => d
633632
634633 case Divide (l, r) => Divide (Cast (l, DoubleType ), Cast (r, DoubleType ))
635634 }
@@ -642,42 +641,33 @@ trait HiveTypeCoercion {
642641 import HiveTypeCoercion ._
643642
644643 def apply (plan : LogicalPlan ): LogicalPlan = plan transformAllExpressions {
645- case cw : CaseWhenLike if cw.childrenResolved && ! cw.valueTypesEqual =>
646- logDebug(s " Input values for null casting ${cw.valueTypes.mkString(" ," )}" )
647- val commonType = cw.valueTypes.reduce { (v1, v2) =>
648- findTightestCommonType(v1, v2).getOrElse(sys.error(
649- s " Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2" ))
650- }
651- val transformedBranches = cw.branches.sliding(2 , 2 ).map {
652- case Seq (when, value) if value.dataType != commonType =>
653- Seq (when, Cast (value, commonType))
654- case Seq (elseVal) if elseVal.dataType != commonType =>
655- Seq (Cast (elseVal, commonType))
656- case s => s
657- }.reduce(_ ++ _)
658- cw match {
659- case _ : CaseWhen =>
660- CaseWhen (transformedBranches)
661- case CaseKeyWhen (key, _) =>
662- CaseKeyWhen (key, transformedBranches)
663- }
664-
665- case ckw : CaseKeyWhen if ckw.childrenResolved && ! ckw.resolved =>
666- val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
667- findTightestCommonType(v1, v2).getOrElse(sys.error(
668- s " Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2" ))
669- }
670- val transformedBranches = ckw.branches.sliding(2 , 2 ).map {
671- case Seq (when, then ) if when.dataType != commonType =>
672- Seq (Cast (when, commonType), then )
673- case s => s
674- }.reduce(_ ++ _)
675- val transformedKey = if (ckw.key.dataType != commonType) {
676- Cast (ckw.key, commonType)
677- } else {
678- ckw.key
679- }
680- CaseKeyWhen (transformedKey, transformedBranches)
644+ case c : CaseWhenLike if c.childrenResolved && ! c.valueTypesEqual =>
645+ logDebug(s " Input values for null casting ${c.valueTypes.mkString(" ," )}" )
646+ val maybeCommonType = findTightestCommonType(c.valueTypes)
647+ maybeCommonType.map { commonType =>
648+ val castedBranches = c.branches.grouped(2 ).map {
649+ case Seq (when, value) if value.dataType != commonType =>
650+ Seq (when, Cast (value, commonType))
651+ case Seq (elseVal) if elseVal.dataType != commonType =>
652+ Seq (Cast (elseVal, commonType))
653+ case other => other
654+ }.reduce(_ ++ _)
655+ c match {
656+ case _ : CaseWhen => CaseWhen (castedBranches)
657+ case CaseKeyWhen (key, _) => CaseKeyWhen (key, castedBranches)
658+ }
659+ }.getOrElse(c)
660+
661+ case c : CaseKeyWhen if c.childrenResolved && ! c.resolved =>
662+ val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType))
663+ maybeCommonType.map { commonType =>
664+ val castedBranches = c.branches.grouped(2 ).map {
665+ case Seq (when, then ) if when.dataType != commonType =>
666+ Seq (Cast (when, commonType), then )
667+ case other => other
668+ }.reduce(_ ++ _)
669+ CaseKeyWhen (Cast (c.key, commonType), castedBranches)
670+ }.getOrElse(c)
681671 }
682672 }
683673
0 commit comments