Skip to content

Commit b045315

Browse files
ueshin authored and cloud-fan committed
[SPARK-24734][SQL] Fix type coercions and nullabilities of nested data types of some functions.
## What changes were proposed in this pull request?

We have some functions which need to be aware of the nullabilities of all their children, such as `CreateArray`, `CreateMap`, `Concat`, and so on. Currently we add casts to fix the nullabilities, but the casts might be removed during the optimization phase. After the discussion, we decided not to add extra casts just to fix the nullabilities of the nested types, but to have the functions handle them themselves.

## How was this patch tested?

Modified and added some tests.

Author: Takuya UESHIN <[email protected]>

Closes #21704 from ueshin/issues/SPARK-24734/concat_containsnull.
1 parent cf97045 commit b045315

File tree

15 files changed

+211
-142
lines changed

15 files changed

+211
-142
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 61 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,17 @@ object TypeCoercion {
184184
}
185185
}
186186

187+
def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = {
188+
if (types.isEmpty) {
189+
None
190+
} else {
191+
types.tail.foldLeft[Option[DataType]](Some(types.head)) {
192+
case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2)
193+
case _ => None
194+
}
195+
}
196+
}
197+
187198
/**
188199
* Case 2 type widening (see the classdoc comment above for TypeCoercion).
189200
*
@@ -259,8 +270,25 @@ object TypeCoercion {
259270
}
260271
}
261272

262-
private def haveSameType(exprs: Seq[Expression]): Boolean =
263-
exprs.map(_.dataType).distinct.length == 1
273+
/**
274+
* Check whether the given types are equal ignoring nullable, containsNull and valueContainsNull.
275+
*/
276+
def haveSameType(types: Seq[DataType]): Boolean = {
277+
if (types.size <= 1) {
278+
true
279+
} else {
280+
val head = types.head
281+
types.tail.forall(_.sameType(head))
282+
}
283+
}
284+
285+
private def castIfNotSameType(expr: Expression, dt: DataType): Expression = {
286+
if (!expr.dataType.sameType(dt)) {
287+
Cast(expr, dt)
288+
} else {
289+
expr
290+
}
291+
}
264292

265293
/**
266294
* Widens numeric types and converts strings to numbers when appropriate.
@@ -525,23 +553,24 @@ object TypeCoercion {
525553
* This ensure that the types for various functions are as expected.
526554
*/
527555
object FunctionArgumentConversion extends TypeCoercionRule {
556+
528557
override protected def coerceTypes(
529558
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
530559
// Skip nodes who's children have not been resolved yet.
531560
case e if !e.childrenResolved => e
532561

533-
case a @ CreateArray(children) if !haveSameType(children) =>
562+
case a @ CreateArray(children) if !haveSameType(children.map(_.dataType)) =>
534563
val types = children.map(_.dataType)
535564
findWiderCommonType(types) match {
536-
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
565+
case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType)))
537566
case None => a
538567
}
539568

540569
case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
541-
!haveSameType(children) =>
570+
!haveSameType(c.inputTypesForMerging) =>
542571
val types = children.map(_.dataType)
543572
findWiderCommonType(types) match {
544-
case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
573+
case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType)))
545574
case None => c
546575
}
547576

@@ -553,41 +582,34 @@ object TypeCoercion {
553582
case None => aj
554583
}
555584

556-
case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren) =>
585+
case s @ Sequence(_, _, _, timeZoneId)
586+
if !haveSameType(s.coercibleChildren.map(_.dataType)) =>
557587
val types = s.coercibleChildren.map(_.dataType)
558588
findWiderCommonType(types) match {
559589
case Some(widerDataType) => s.castChildrenTo(widerDataType)
560590
case None => s
561591
}
562592

563593
case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) &&
564-
!haveSameType(children) =>
594+
!haveSameType(m.inputTypesForMerging) =>
565595
val types = children.map(_.dataType)
566596
findWiderCommonType(types) match {
567-
case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType)))
597+
case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType)))
568598
case None => m
569599
}
570600

571601
case m @ CreateMap(children) if m.keys.length == m.values.length &&
572-
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
573-
val newKeys = if (haveSameType(m.keys)) {
574-
m.keys
575-
} else {
576-
val types = m.keys.map(_.dataType)
577-
findWiderCommonType(types) match {
578-
case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
579-
case None => m.keys
580-
}
602+
(!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) =>
603+
val keyTypes = m.keys.map(_.dataType)
604+
val newKeys = findWiderCommonType(keyTypes) match {
605+
case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType))
606+
case None => m.keys
581607
}
582608

583-
val newValues = if (haveSameType(m.values)) {
584-
m.values
585-
} else {
586-
val types = m.values.map(_.dataType)
587-
findWiderCommonType(types) match {
588-
case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
589-
case None => m.values
590-
}
609+
val valueTypes = m.values.map(_.dataType)
610+
val newValues = findWiderCommonType(valueTypes) match {
611+
case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType))
612+
case None => m.values
591613
}
592614

593615
CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
@@ -610,27 +632,27 @@ object TypeCoercion {
610632
// Coalesce should return the first non-null value, which could be any column
611633
// from the list. So we need to make sure the return type is deterministic and
612634
// compatible with every child column.
613-
case c @ Coalesce(es) if !haveSameType(es) =>
635+
case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) =>
614636
val types = es.map(_.dataType)
615637
findWiderCommonType(types) match {
616-
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
638+
case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType)))
617639
case None => c
618640
}
619641

620642
// When finding wider type for `Greatest` and `Least`, we should handle decimal types even if
621643
// we need to truncate, but we should not promote one side to string if the other side is
622644
// string.g
623-
case g @ Greatest(children) if !haveSameType(children) =>
645+
case g @ Greatest(children) if !haveSameType(g.inputTypesForMerging) =>
624646
val types = children.map(_.dataType)
625647
findWiderTypeWithoutStringPromotion(types) match {
626-
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
648+
case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType)))
627649
case None => g
628650
}
629651

630-
case l @ Least(children) if !haveSameType(children) =>
652+
case l @ Least(children) if !haveSameType(l.inputTypesForMerging) =>
631653
val types = children.map(_.dataType)
632654
findWiderTypeWithoutStringPromotion(types) match {
633-
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
655+
case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType)))
634656
case None => l
635657
}
636658

@@ -672,27 +694,14 @@ object TypeCoercion {
672694
object CaseWhenCoercion extends TypeCoercionRule {
673695
override protected def coerceTypes(
674696
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
675-
case c: CaseWhen if c.childrenResolved && !c.areInputTypesForMergingEqual =>
697+
case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) =>
676698
val maybeCommonType = findWiderCommonType(c.inputTypesForMerging)
677699
maybeCommonType.map { commonType =>
678-
var changed = false
679700
val newBranches = c.branches.map { case (condition, value) =>
680-
if (value.dataType.sameType(commonType)) {
681-
(condition, value)
682-
} else {
683-
changed = true
684-
(condition, Cast(value, commonType))
685-
}
686-
}
687-
val newElseValue = c.elseValue.map { value =>
688-
if (value.dataType.sameType(commonType)) {
689-
value
690-
} else {
691-
changed = true
692-
Cast(value, commonType)
693-
}
701+
(condition, castIfNotSameType(value, commonType))
694702
}
695-
if (changed) CaseWhen(newBranches, newElseValue) else c
703+
val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType))
704+
CaseWhen(newBranches, newElseValue)
696705
}.getOrElse(c)
697706
}
698707
}
@@ -705,10 +714,10 @@ object TypeCoercion {
705714
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
706715
case e if !e.childrenResolved => e
707716
// Find tightest common type for If, if the true value and false value have different types.
708-
case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual =>
717+
case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) =>
709718
findWiderTypeForTwo(left.dataType, right.dataType).map { widestType =>
710-
val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType)
711-
val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType)
719+
val newLeft = castIfNotSameType(left, widestType)
720+
val newRight = castIfNotSameType(right, widestType)
712721
If(pred, newLeft, newRight)
713722
}.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
714723
case If(Literal(null, NullType), left, right) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -709,22 +709,12 @@ trait ComplexTypeMergingExpression extends Expression {
709709
@transient
710710
lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType)
711711

712-
/**
713-
* A method determining whether the input types are equal ignoring nullable, containsNull and
714-
* valueContainsNull flags and thus convenient for resolution of the final data type.
715-
*/
716-
def areInputTypesForMergingEqual: Boolean = {
717-
inputTypesForMerging.length <= 1 || inputTypesForMerging.sliding(2, 1).forall {
718-
case Seq(dt1, dt2) => dt1.sameType(dt2)
719-
}
720-
}
721-
722712
override def dataType: DataType = {
723713
require(
724714
inputTypesForMerging.nonEmpty,
725715
"The collection of input data types must not be empty.")
726716
require(
727-
areInputTypesForMergingEqual,
717+
TypeCoercion.haveSameType(inputTypesForMerging),
728718
"All input types must be the same except nullable, containsNull, valueContainsNull flags.")
729719
inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get)
730720
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
21-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
21+
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
2323
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2424
import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -514,7 +514,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
514514
> SELECT _FUNC_(10, 9, 2, 4, 3);
515515
2
516516
""")
517-
case class Least(children: Seq[Expression]) extends Expression {
517+
case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression {
518518

519519
override def nullable: Boolean = children.forall(_.nullable)
520520
override def foldable: Boolean = children.forall(_.foldable)
@@ -525,7 +525,7 @@ case class Least(children: Seq[Expression]) extends Expression {
525525
if (children.length <= 1) {
526526
TypeCheckResult.TypeCheckFailure(
527527
s"input to function $prettyName requires at least two arguments")
528-
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
528+
} else if (!TypeCoercion.haveSameType(inputTypesForMerging)) {
529529
TypeCheckResult.TypeCheckFailure(
530530
s"The expressions should all have the same type," +
531531
s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).")
@@ -534,8 +534,6 @@ case class Least(children: Seq[Expression]) extends Expression {
534534
}
535535
}
536536

537-
override def dataType: DataType = children.head.dataType
538-
539537
override def eval(input: InternalRow): Any = {
540538
children.foldLeft[Any](null)((r, c) => {
541539
val evalc = c.eval(input)
@@ -589,7 +587,7 @@ case class Least(children: Seq[Expression]) extends Expression {
589587
> SELECT _FUNC_(10, 9, 2, 4, 3);
590588
10
591589
""")
592-
case class Greatest(children: Seq[Expression]) extends Expression {
590+
case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression {
593591

594592
override def nullable: Boolean = children.forall(_.nullable)
595593
override def foldable: Boolean = children.forall(_.foldable)
@@ -600,7 +598,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
600598
if (children.length <= 1) {
601599
TypeCheckResult.TypeCheckFailure(
602600
s"input to function $prettyName requires at least two arguments")
603-
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
601+
} else if (!TypeCoercion.haveSameType(inputTypesForMerging)) {
604602
TypeCheckResult.TypeCheckFailure(
605603
s"The expressions should all have the same type," +
606604
s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).")
@@ -609,8 +607,6 @@ case class Greatest(children: Seq[Expression]) extends Expression {
609607
}
610608
}
611609

612-
override def dataType: DataType = children.head.dataType
613-
614610
override def eval(input: InternalRow): Any = {
615611
children.foldLeft[Any](null)((r, c) => {
616612
val evalc = c.eval(input)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
507507
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
508508
[[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]]
509509
""", since = "2.4.0")
510-
case class MapConcat(children: Seq[Expression]) extends Expression {
510+
case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression {
511511

512512
override def checkInputDataTypes(): TypeCheckResult = {
513513
var funcName = s"function $prettyName"
@@ -521,14 +521,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression {
521521
}
522522

523523
override def dataType: MapType = {
524-
val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption
525-
.getOrElse(MapType(StringType, StringType))
526-
val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType])
527-
.exists(_.valueContainsNull)
528-
if (dt.valueContainsNull != valueContainsNull) {
529-
dt.copy(valueContainsNull = valueContainsNull)
524+
if (children.isEmpty) {
525+
MapType(StringType, StringType)
530526
} else {
531-
dt
527+
super.dataType.asInstanceOf[MapType]
532528
}
533529
}
534530

@@ -2211,7 +2207,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
22112207
> SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
22122208
| [1,2,3,4,5,6]
22132209
""")
2214-
case class Concat(children: Seq[Expression]) extends Expression {
2210+
case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression {
22152211

22162212
private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
22172213

@@ -2232,7 +2228,13 @@ case class Concat(children: Seq[Expression]) extends Expression {
22322228
}
22332229
}
22342230

2235-
override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
2231+
override def dataType: DataType = {
2232+
if (children.isEmpty) {
2233+
StringType
2234+
} else {
2235+
super.dataType
2236+
}
2237+
}
22362238

22372239
lazy val javaType: String = CodeGenerator.javaType(dataType)
22382240

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
21+
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
2122
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
22-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2323
import org.apache.spark.sql.catalyst.expressions.codegen._
2424
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2525
import org.apache.spark.sql.catalyst.util._
@@ -48,7 +48,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
4848

4949
override def dataType: ArrayType = {
5050
ArrayType(
51-
children.headOption.map(_.dataType).getOrElse(StringType),
51+
TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType))
52+
.getOrElse(StringType),
5253
containsNull = children.exists(_.nullable))
5354
}
5455

@@ -179,11 +180,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
179180
if (children.size % 2 != 0) {
180181
TypeCheckResult.TypeCheckFailure(
181182
s"$prettyName expects a positive even number of arguments.")
182-
} else if (keys.map(_.dataType).distinct.length > 1) {
183+
} else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) {
183184
TypeCheckResult.TypeCheckFailure(
184185
"The given keys of function map should all be the same type, but they are " +
185186
keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
186-
} else if (values.map(_.dataType).distinct.length > 1) {
187+
} else if (!TypeCoercion.haveSameType(values.map(_.dataType))) {
187188
TypeCheckResult.TypeCheckFailure(
188189
"The given values of function map should all be the same type, but they are " +
189190
values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
@@ -194,8 +195,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
194195

195196
override def dataType: DataType = {
196197
MapType(
197-
keyType = keys.headOption.map(_.dataType).getOrElse(StringType),
198-
valueType = values.headOption.map(_.dataType).getOrElse(StringType),
198+
keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType))
199+
.getOrElse(StringType),
200+
valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType))
201+
.getOrElse(StringType),
199202
valueContainsNull = values.exists(_.nullable))
200203
}
201204

0 commit comments

Comments
 (0)