@@ -184,6 +184,17 @@ object TypeCoercion {
184184 }
185185 }
186186
187+ def findCommonTypeDifferentOnlyInNullFlags (types : Seq [DataType ]): Option [DataType ] = {
188+ if (types.isEmpty) {
189+ None
190+ } else {
191+ types.tail.foldLeft[Option [DataType ]](Some (types.head)) {
192+ case (Some (t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2)
193+ case _ => None
194+ }
195+ }
196+ }
197+
187198 /**
188199 * Case 2 type widening (see the classdoc comment above for TypeCoercion).
189200 *
@@ -259,8 +270,25 @@ object TypeCoercion {
259270 }
260271 }
261272
262- private def haveSameType (exprs : Seq [Expression ]): Boolean =
263- exprs.map(_.dataType).distinct.length == 1
273+ /**
274+ * Check whether the given types are equal ignoring nullable, containsNull and valueContainsNull.
275+ */
276+ def haveSameType (types : Seq [DataType ]): Boolean = {
277+ if (types.size <= 1 ) {
278+ true
279+ } else {
280+ val head = types.head
281+ types.tail.forall(_.sameType(head))
282+ }
283+ }
284+
285+ private def castIfNotSameType (expr : Expression , dt : DataType ): Expression = {
286+ if (! expr.dataType.sameType(dt)) {
287+ Cast (expr, dt)
288+ } else {
289+ expr
290+ }
291+ }
264292
265293 /**
266294 * Widens numeric types and converts strings to numbers when appropriate.
@@ -525,23 +553,24 @@ object TypeCoercion {
525553 * This ensure that the types for various functions are as expected.
526554 */
527555 object FunctionArgumentConversion extends TypeCoercionRule {
556+
528557 override protected def coerceTypes (
529558 plan : LogicalPlan ): LogicalPlan = plan transformAllExpressions {
530559 // Skip nodes who's children have not been resolved yet.
531560 case e if ! e.childrenResolved => e
532561
533- case a @ CreateArray (children) if ! haveSameType(children) =>
562+ case a @ CreateArray (children) if ! haveSameType(children.map(_.dataType) ) =>
534563 val types = children.map(_.dataType)
535564 findWiderCommonType(types) match {
536- case Some (finalDataType) => CreateArray (children.map(Cast (_, finalDataType)))
565+ case Some (finalDataType) => CreateArray (children.map(castIfNotSameType (_, finalDataType)))
537566 case None => a
538567 }
539568
540569 case c @ Concat (children) if children.forall(c => ArrayType .acceptsType(c.dataType)) &&
541- ! haveSameType(children ) =>
570+ ! haveSameType(c.inputTypesForMerging ) =>
542571 val types = children.map(_.dataType)
543572 findWiderCommonType(types) match {
544- case Some (finalDataType) => Concat (children.map(Cast (_, finalDataType)))
573+ case Some (finalDataType) => Concat (children.map(castIfNotSameType (_, finalDataType)))
545574 case None => c
546575 }
547576
@@ -553,41 +582,34 @@ object TypeCoercion {
553582 case None => aj
554583 }
555584
556- case s @ Sequence (_, _, _, timeZoneId) if ! haveSameType(s.coercibleChildren) =>
585+ case s @ Sequence (_, _, _, timeZoneId)
586+ if ! haveSameType(s.coercibleChildren.map(_.dataType)) =>
557587 val types = s.coercibleChildren.map(_.dataType)
558588 findWiderCommonType(types) match {
559589 case Some (widerDataType) => s.castChildrenTo(widerDataType)
560590 case None => s
561591 }
562592
563593 case m @ MapConcat (children) if children.forall(c => MapType .acceptsType(c.dataType)) &&
564- ! haveSameType(children ) =>
594+ ! haveSameType(m.inputTypesForMerging ) =>
565595 val types = children.map(_.dataType)
566596 findWiderCommonType(types) match {
567- case Some (finalDataType) => MapConcat (children.map(Cast (_, finalDataType)))
597+ case Some (finalDataType) => MapConcat (children.map(castIfNotSameType (_, finalDataType)))
568598 case None => m
569599 }
570600
571601 case m @ CreateMap (children) if m.keys.length == m.values.length &&
572- (! haveSameType(m.keys) || ! haveSameType(m.values)) =>
573- val newKeys = if (haveSameType(m.keys)) {
574- m.keys
575- } else {
576- val types = m.keys.map(_.dataType)
577- findWiderCommonType(types) match {
578- case Some (finalDataType) => m.keys.map(Cast (_, finalDataType))
579- case None => m.keys
580- }
602+ (! haveSameType(m.keys.map(_.dataType)) || ! haveSameType(m.values.map(_.dataType))) =>
603+ val keyTypes = m.keys.map(_.dataType)
604+ val newKeys = findWiderCommonType(keyTypes) match {
605+ case Some (finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType))
606+ case None => m.keys
581607 }
582608
583- val newValues = if (haveSameType(m.values)) {
584- m.values
585- } else {
586- val types = m.values.map(_.dataType)
587- findWiderCommonType(types) match {
588- case Some (finalDataType) => m.values.map(Cast (_, finalDataType))
589- case None => m.values
590- }
609+ val valueTypes = m.values.map(_.dataType)
610+ val newValues = findWiderCommonType(valueTypes) match {
611+ case Some (finalDataType) => m.values.map(castIfNotSameType(_, finalDataType))
612+ case None => m.values
591613 }
592614
593615 CreateMap (newKeys.zip(newValues).flatMap { case (k, v) => Seq (k, v) })
@@ -610,27 +632,27 @@ object TypeCoercion {
610632 // Coalesce should return the first non-null value, which could be any column
611633 // from the list. So we need to make sure the return type is deterministic and
612634 // compatible with every child column.
613- case c @ Coalesce (es) if ! haveSameType(es ) =>
635+ case c @ Coalesce (es) if ! haveSameType(c.inputTypesForMerging ) =>
614636 val types = es.map(_.dataType)
615637 findWiderCommonType(types) match {
616- case Some (finalDataType) => Coalesce (es.map(Cast (_, finalDataType)))
638+ case Some (finalDataType) => Coalesce (es.map(castIfNotSameType (_, finalDataType)))
617639 case None => c
618640 }
619641
620642 // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if
621643 // we need to truncate, but we should not promote one side to string if the other side is
622644 // string.g
623- case g @ Greatest (children) if ! haveSameType(children ) =>
645+ case g @ Greatest (children) if ! haveSameType(g.inputTypesForMerging ) =>
624646 val types = children.map(_.dataType)
625647 findWiderTypeWithoutStringPromotion(types) match {
626- case Some (finalDataType) => Greatest (children.map(Cast (_, finalDataType)))
648+ case Some (finalDataType) => Greatest (children.map(castIfNotSameType (_, finalDataType)))
627649 case None => g
628650 }
629651
630- case l @ Least (children) if ! haveSameType(children ) =>
652+ case l @ Least (children) if ! haveSameType(l.inputTypesForMerging ) =>
631653 val types = children.map(_.dataType)
632654 findWiderTypeWithoutStringPromotion(types) match {
633- case Some (finalDataType) => Least (children.map(Cast (_, finalDataType)))
655+ case Some (finalDataType) => Least (children.map(castIfNotSameType (_, finalDataType)))
634656 case None => l
635657 }
636658
@@ -672,27 +694,14 @@ object TypeCoercion {
672694 object CaseWhenCoercion extends TypeCoercionRule {
673695 override protected def coerceTypes (
674696 plan : LogicalPlan ): LogicalPlan = plan transformAllExpressions {
675- case c : CaseWhen if c.childrenResolved && ! c.areInputTypesForMergingEqual =>
697+ case c : CaseWhen if c.childrenResolved && ! haveSameType(c.inputTypesForMerging) =>
676698 val maybeCommonType = findWiderCommonType(c.inputTypesForMerging)
677699 maybeCommonType.map { commonType =>
678- var changed = false
679700 val newBranches = c.branches.map { case (condition, value) =>
680- if (value.dataType.sameType(commonType)) {
681- (condition, value)
682- } else {
683- changed = true
684- (condition, Cast (value, commonType))
685- }
686- }
687- val newElseValue = c.elseValue.map { value =>
688- if (value.dataType.sameType(commonType)) {
689- value
690- } else {
691- changed = true
692- Cast (value, commonType)
693- }
701+ (condition, castIfNotSameType(value, commonType))
694702 }
695- if (changed) CaseWhen (newBranches, newElseValue) else c
703+ val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType))
704+ CaseWhen (newBranches, newElseValue)
696705 }.getOrElse(c)
697706 }
698707 }
@@ -705,10 +714,10 @@ object TypeCoercion {
705714 plan : LogicalPlan ): LogicalPlan = plan transformAllExpressions {
706715 case e if ! e.childrenResolved => e
707716 // Find tightest common type for If, if the true value and false value have different types.
708- case i @ If (pred, left, right) if ! i.areInputTypesForMergingEqual =>
717+ case i @ If (pred, left, right) if ! haveSameType(i.inputTypesForMerging) =>
709718 findWiderTypeForTwo(left.dataType, right.dataType).map { widestType =>
710- val newLeft = if (left.dataType.sameType(widestType)) left else Cast (left, widestType)
711- val newRight = if (right.dataType.sameType(widestType)) right else Cast (right, widestType)
719+ val newLeft = castIfNotSameType (left, widestType)
720+ val newRight = castIfNotSameType (right, widestType)
712721 If (pred, newLeft, newRight)
713722 }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
714723 case If (Literal (null , NullType ), left, right) =>
0 commit comments