diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 82d689477080d..f7fe467cea830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -144,7 +144,7 @@ object TimeWindow { case class PreciseTimestampConversion( child: Expression, fromType: DataType, - toType: DataType) extends UnaryExpression with ExpectsInputTypes { + toType: DataType) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(fromType) override def dataType: DataType = toType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 7b819db32e425..342b14eaa3390 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -127,7 +127,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme > SELECT _FUNC_ 0; -1 """) -case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class BitwiseNot(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) @@ -164,7 +165,8 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp 0 """, since = "3.0.0") -case class BitwiseCount(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class BitwiseCount(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegralType, BooleanType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4fd68dcfe5156..b32e9ee05f1ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -141,7 +141,7 @@ object Size { """, group = "map_funcs") case class MapKeys(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -332,7 +332,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI """, group = "map_funcs") case class MapValues(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -361,7 +361,8 @@ case class MapValues(child: Expression) """, group = "map_funcs", since = "3.0.0") -case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class MapEntries(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -649,7 +650,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres """, group = "map_funcs", since = "2.4.0") -case class MapFromEntries(child: Expression) extends UnaryExpression { +case class MapFromEntries(child: Expression) extends UnaryExpression with NullIntolerant { @transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { @@ -873,7 +874,7 @@ object ArraySortLike { group = "array_funcs") // scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ArraySortLike { + extends BinaryExpression with ArraySortLike with NullIntolerant { def this(e: Expression) = this(e, Literal(true)) @@ -1017,7 +1018,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) Reverse logic for arrays is available since 2.4.0. """ ) -case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Reverse(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { // Input types are utilized by type coercion in ImplicitTypeCasts. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) @@ -1086,7 +1088,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI """, group = "array_funcs") case class ArrayContains(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BooleanType @@ -1185,7 +1187,7 @@ case class ArrayContains(left: Expression, right: Expression) since = "2.4.0") // scalastyle:off line.size.limit case class ArraysOverlap(left: Expression, right: Expression) - extends BinaryArrayExpressionWithImplicitCast { + extends BinaryArrayExpressionWithImplicitCast with NullIntolerant { override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => @@ -1410,7 +1412,7 @@ case class ArraysOverlap(left: Expression, right: Expression) since = "2.4.0") // scalastyle:on line.size.limit case class Slice(x: Expression, start: Expression, length: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = x.dataType @@ -1688,7 +1690,8 @@ case class ArrayJoin( """, group = "array_funcs", since = "2.4.0") -case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class ArrayMin(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def nullable: Boolean = true @@ -1755,7 +1758,8 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast """, group = "array_funcs", since = "2.4.0") -case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class ArrayMax(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def nullable: Boolean = true @@ -1831,7 +1835,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast group = "array_funcs", since = "2.4.0") case class ArrayPosition(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -1909,7 +1913,7 @@ case class ArrayPosition(left: Expression, right: Expression) """, since = "2.4.0") case class ElementAt(left: Expression, right: Expression) - extends GetMapValueUtil with GetArrayItemUtil { + extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant { @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType @@ -2245,7 +2249,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio """, group = "array_funcs", since = "2.4.0") -case class Flatten(child: Expression) extends UnaryExpression { +case class Flatten(child: Expression) extends UnaryExpression with NullIntolerant { private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] @@ -2884,7 +2888,7 @@ case class ArrayRepeat(left: Expression, right: Expression) group = "array_funcs", since = "2.4.0") case class ArrayRemove(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = left.dataType @@ -3081,7 +3085,7 @@ trait ArraySetLike { group = "array_funcs", since = "2.4.0") case class ArrayDistinct(child: Expression) - extends UnaryExpression with ArraySetLike with ExpectsInputTypes { + extends UnaryExpression with ArraySetLike with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) @@ -3219,7 +3223,8 @@ case class ArrayDistinct(child: Expression) /** * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. */ -trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with ArraySetLike { +trait ArrayBinaryLike + extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant { override protected def dt: DataType = dataType override protected def et: DataType = elementType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 5212ef3930bc9..1b4a705e804f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -255,7 +255,7 @@ object CreateMap { {1.0:"2",3.0:"4"} """, since = "2.4.0") case class MapFromArrays(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) @@ -476,7 +476,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { since = "2.0.1") // scalastyle:on line.size.limit case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression) - extends TernaryExpression with ExpectsInputTypes { + extends TernaryExpression with ExpectsInputTypes with NullIntolerant { def this(child: Expression, pairDelim: Expression) = { this(child, pairDelim, Literal(":")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 5140db90c5954..f9ccf3c8c811f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -211,7 +211,8 @@ case class StructsToCsv( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes + with NullIntolerant { override def nullable: Boolean = true def this(options: Map[String, String], child: Expression) = this(options, child, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 7dc008a2e5df8..4f3db1b8a57ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -198,7 +198,7 @@ case class CurrentBatchTimestamp( group = "datetime_funcs", since = "1.5.0") case class DateAdd(startDate: Expression, days: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = days @@ -234,7 +234,7 @@ case class DateAdd(startDate: Expression, days: Expression) group = "datetime_funcs", since = "1.5.0") case class DateSub(startDate: Expression, days: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = days @@ -266,7 +266,8 @@ case class DateSub(startDate: Expression, days: Expression) group = "datetime_funcs", since = "1.5.0") case class Hour(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -298,7 +299,8 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None) group = "datetime_funcs", since = "1.5.0") case class Minute(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -330,7 +332,8 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None) group = "datetime_funcs", since = "1.5.0") case class Second(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -353,7 +356,8 @@ case class Second(child: Expression, timeZoneId: Option[String] = None) } case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -385,7 +389,8 @@ case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = No """, group = "datetime_funcs", since = "1.5.0") -case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfYear(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -402,7 +407,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas } abstract class NumberToTimestampBase extends UnaryExpression - with ExpectsInputTypes { + with ExpectsInputTypes with NullIntolerant { protected def upScaleFactor: Long @@ -487,7 +492,8 @@ case class MicrosToTimestamp(child: Expression) """, group = "datetime_funcs", since = "1.5.0") -case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Year(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -503,7 +509,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu } } -case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class YearOfWeek(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -528,7 +535,8 @@ case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCa """, group = "datetime_funcs", since = "1.5.0") -case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Quarter(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -553,7 +561,8 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI """, group = "datetime_funcs", since = "1.5.0") -case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Month(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -577,7 +586,8 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp 30 """, since = "1.5.0") -case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfMonth(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -647,7 +657,7 @@ case class WeekDay(child: Expression) extends DayWeek { } } -abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { +abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -665,7 +675,8 @@ abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { group = "datetime_funcs", since = "1.5.0") // scalastyle:on line.size.limit -case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class WeekOfYear(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -704,7 +715,8 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa since = "1.5.0") // scalastyle:on line.size.limit case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(left: Expression, right: Expression) = this(left, right, None) @@ -1154,7 +1166,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ """, group = "datetime_funcs", since = "1.5.0") -case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class LastDay(startDate: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def child: Expression = startDate override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -1192,7 +1205,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC since = "1.5.0") // scalastyle:on line.size.limit case class NextDay(startDate: Expression, dayOfWeek: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = dayOfWeek @@ -1248,7 +1261,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) * Adds an interval to timestamp. */ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { def this(start: Expression, interval: Expression) = this(start, interval, None) @@ -1306,7 +1319,7 @@ case class DateAddInterval( interval: Expression, timeZoneId: Option[String] = None, ansiEnabled: Boolean = SQLConf.get.ansiEnabled) - extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression { + extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression with NullIntolerant { override def left: Expression = start override def right: Expression = interval @@ -1380,7 +1393,7 @@ case class DateAddInterval( since = "1.5.0") // scalastyle:on line.size.limit case class FromUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override def dataType: DataType = TimestampType @@ -1440,7 +1453,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class AddMonths(startDate: Expression, numMonths: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = numMonths @@ -1494,7 +1507,8 @@ case class MonthsBetween( date2: Expression, roundOff: Expression, timeZoneId: Option[String] = None) - extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) @@ -1552,7 +1566,7 @@ case class MonthsBetween( since = "1.5.0") // scalastyle:on line.size.limit case class ToUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override def dataType: DataType = TimestampType @@ -1906,7 +1920,7 @@ case class TruncTimestamp( group = "datetime_funcs", since = "1.5.0") case class DateDiff(endDate: Expression, startDate: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = endDate override def right: Expression = startDate @@ -1960,7 +1974,7 @@ private case class GetTimestamp( group = "datetime_funcs", since = "3.0.0") case class MakeDate(year: Expression, month: Expression, day: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def children: Seq[Expression] = Seq(year, month, day) override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType) @@ -2031,7 +2045,8 @@ case class MakeTimestamp( sec: Expression, timezone: Option[Expression] = None, timeZoneId: Option[String] = None) - extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this( year: Expression, @@ -2307,7 +2322,7 @@ case class Extract(field: Expression, source: Expression, child: Expression) * between the given timestamps. */ case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = endTimestamp override def right: Expression = startTimestamp @@ -2328,7 +2343,7 @@ case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expressi * Returns the interval from the `left` date (inclusive) to the `right` date (exclusive). */ case class SubtractDates(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType) override def dataType: DataType = CalendarIntervalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 9014ebfe2f96a..c2c70b2ab08e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._ * Note: this expression is internal and created only by the optimizer, * we don't need to do type check for it. */ -case class UnscaledValue(child: Expression) extends UnaryExpression { +case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant { override def dataType: DataType = LongType override def toString: String = s"UnscaledValue($child)" @@ -49,7 +49,7 @@ case class MakeDecimal( child: Expression, precision: Int, scale: Int, - nullOnOverflow: Boolean) extends UnaryExpression { + nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant { def this(child: Expression, precision: Int, scale: Int) = { this(child, precision, scale, !SQLConf.get.ansiEnabled) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 4c8c58ae232f4..5e21b58f070ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -53,7 +53,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} > SELECT _FUNC_('Spark'); 8cde774d6f7333752ed72cacddb05126 """) -case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Md5(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -89,7 +90,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput """) // scalastyle:on line.size.limit case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def dataType: DataType = StringType override def nullable: Boolean = true @@ -160,7 +161,8 @@ case class Sha2(left: Expression, right: Expression) > SELECT _FUNC_('Spark'); 85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c """) -case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Sha1(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -187,7 +189,8 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu > SELECT _FUNC_('Spark'); 1557323817 """) -case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Crc32(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 1a569a7b89fe1..baab224691bc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -31,7 +31,7 @@ abstract class ExtractIntervalPart( val dataType: DataType, func: CalendarInterval => Any, funcName: String) - extends UnaryExpression with ExpectsInputTypes with Serializable { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType) @@ -82,7 +82,7 @@ object ExtractIntervalPart { abstract class IntervalNumOperation( interval: Expression, num: Expression) - extends BinaryExpression with ImplicitCastInputTypes with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def left: Expression = interval override def right: Expression = num @@ -160,7 +160,7 @@ case class MakeInterval( hours: Expression, mins: Expression, secs: Expression) - extends SeptenaryExpression with ImplicitCastInputTypes { + extends SeptenaryExpression with ImplicitCastInputTypes with NullIntolerant { def this( years: Expression, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 205e5271517c3..f4568f860ac0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -519,7 +519,8 @@ case class JsonToStructs( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes + with NullIntolerant { // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder @@ -638,7 +639,8 @@ case class StructsToJson( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback + with ExpectsInputTypes with NullIntolerant { override def nullable: Boolean = true def this(options: Map[String, String], child: Expression) = this(options, child, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 66e6334e3a450..8806fc68f1306 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -57,7 +57,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(val f: Double => Double, name: String) - extends UnaryExpression with Serializable with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -111,7 +111,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -324,7 +324,7 @@ case class Acosh(child: Expression) -16 """) case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) @@ -452,7 +452,8 @@ object Factorial { > SELECT _FUNC_(5); 120 """) -case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Factorial(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -732,7 +733,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia """) // scalastyle:on line.size.limit case class Bin(child: Expression) - extends UnaryExpression with Serializable with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType @@ -831,7 +832,8 @@ object Hex { > SELECT _FUNC_('Spark SQL'); 537061726B2053514C """) -case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Hex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, BinaryType, StringType)) @@ -866,7 +868,8 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput > SELECT decode(_FUNC_('537061726B2053514C'), 'UTF-8'); Spark SQL """) -case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Unhex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType) @@ -952,7 +955,7 @@ case class Pow(left: Expression, right: Expression) 4 """) case class ShiftLeft(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -986,7 +989,7 @@ case class ShiftLeft(left: Expression, right: Expression) 2 """) case class ShiftRight(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -1020,7 +1023,7 @@ case class ShiftRight(left: Expression, right: Expression) 2 """) case class ShiftRightUnsigned(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 3f60ca388a807..28924fac48eef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -283,7 +283,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress """, since = "1.5.0") case class StringSplit(str: Expression, regex: Expression, limit: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = ArrayType(StringType) override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -325,7 +325,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { // last regex in string, we will update the pattern iff regexp value changed. @transient private var lastRegex: UTF8String = _ @@ -433,7 +433,7 @@ object RegExpExtract { """, since = "1.5.0") case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) // last regex in string, we will update the pattern iff regexp value changed. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 0b9fb8f85fe3c..7a8ab17c13f38 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -334,7 +334,7 @@ trait String2StringExpression extends ImplicitCastInputTypes { """, since = "1.0.1") case class Upper(child: Expression) - extends UnaryExpression with String2StringExpression { + extends UnaryExpression with String2StringExpression with NullIntolerant { // scalastyle:off caselocale override def convert(v: UTF8String): UTF8String = v.toUpperCase @@ -356,7 +356,8 @@ case class Upper(child: Expression) sparksql """, since = "1.0.1") -case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { +case class Lower(child: Expression) + extends UnaryExpression with String2StringExpression with NullIntolerant { // scalastyle:off caselocale override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -432,7 +433,7 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate since = "2.3.0") // scalastyle:on line.size.limit case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(srcExpr: Expression, searchExpr: Expression) = { this(srcExpr, searchExpr, Literal("")) @@ -598,7 +599,7 @@ object StringTranslate { since = "1.5.0") // scalastyle:on line.size.limit case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { @transient private var lastMatching: UTF8String = _ @transient private var lastReplace: UTF8String = _ @@ -663,7 +664,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac since = "1.5.0") // scalastyle:on line.size.limit case class FindInSet(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { + with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -1035,7 +1036,7 @@ case class StringTrimRight( since = "1.5.0") // scalastyle:on line.size.limit case class StringInstr(str: Expression, substr: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = str override def right: Expression = substr @@ -1077,7 +1078,7 @@ case class StringInstr(str: Expression, substr: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -1205,7 +1206,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) """, since = "1.5.0") case class StringLPad(str: Expression, len: Expression, pad: Expression = Literal(" ")) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { this(str, len, Literal(" ")) @@ -1246,7 +1247,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera """, since = "1.5.0") case class StringRPad(str: Expression, len: Expression, pad: Expression = Literal(" ")) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { this(str, len, Literal(" ")) @@ -1536,7 +1537,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC Spark Sql """, since = "1.5.0") -case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class InitCap(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[DataType] = Seq(StringType) override def dataType: DataType = StringType @@ -1563,7 +1565,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI """, since = "1.5.0") case class StringRepeat(str: Expression, times: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = str override def right: Expression = times @@ -1593,7 +1595,7 @@ case class StringRepeat(str: Expression, times: Expression) """, since = "1.5.0") case class StringSpace(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -1738,7 +1740,8 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run """, since = "1.5.0") // scalastyle:on line.size.limit -case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Length(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1766,7 +1769,8 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn 72 """, since = "2.3.0") -case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class BitLength(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1797,7 +1801,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas 9 """, since = "2.3.0") -case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class OctetLength(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1828,7 +1833,7 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC """, since = "1.5.0") case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { + with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -1853,7 +1858,8 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres M460 """, since = "1.5.0") -case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class SoundEx(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -1879,7 +1885,8 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT 50 """, since = "1.5.0") -case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Ascii(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -1921,7 +1928,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp """, since = "2.3.0") // scalastyle:on line.size.limit -case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Chr(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(LongType) @@ -1964,7 +1972,8 @@ case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInput U3BhcmsgU1FM """, since = "1.5.0") -case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Base64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -1992,7 +2001,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn Spark SQL """, since = "1.5.0") -case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class UnBase64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -2024,7 +2034,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast since = "1.5.0") // scalastyle:on line.size.limit case class Decode(bin: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = bin override def right: Expression = charset @@ -2064,7 +2074,7 @@ case class Decode(bin: Expression, charset: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class Encode(value: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = value override def right: Expression = charset @@ -2108,7 +2118,7 @@ case class Encode(value: Expression, charset: Expression) """, since = "1.5.0") case class FormatNumber(x: Expression, d: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = x override def right: Expression = d diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index 55e06cb9e8471..e08a10ecac71c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -30,7 +30,8 @@ import org.apache.spark.unsafe.types.UTF8String * * This is not the world's most efficient implementation due to type conversion, but works. */ -abstract class XPathExtract extends BinaryExpression with ExpectsInputTypes with CodegenFallback { +abstract class XPathExtract + extends BinaryExpression with ExpectsInputTypes with CodegenFallback with NullIntolerant { override def left: Expression = xml override def right: Expression = path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index e18514c6f93f9..53f9757750735 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.expressions import scala.collection.parallel.immutable.ParVector import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.HiveResult.hiveResultString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { @@ -156,4 +157,38 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { } } } + + test("Check whether SQL expressions should extend NullIntolerant") { + // Only check expressions extended from these expressions because these expressions are + // NullIntolerant by default. + val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression], + classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression]) + + // Do not check these expressions, because these expressions extend NullIntolerant + // and override the eval method to avoid evaluating input1 if input2 is 0. + val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod]) + + val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction() + .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName) + .filterNot(c => ignoreSet.exists(_.getName.equals(c))) + .map(name => Utils.classForName(name)) + .filterNot(classOf[NonSQLExpression].isAssignableFrom) + + exprTypesToCheck.foreach { superClass => + candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz => + val isEvalOverrode = clazz.getMethod("eval", classOf[InternalRow]) != + superClass.getMethod("eval", classOf[InternalRow]) + val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz) + if (isEvalOverrode && isNullIntolerantMixedIn) { + fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " + + s"or add ${clazz.getName} in the ignoreSet of this test.") + } else if (!isEvalOverrode && !isNullIntolerantMixedIn) { + fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.") + } else { + assert((!isEvalOverrode && isNullIntolerantMixedIn) || + (isEvalOverrode && !isNullIntolerantMixedIn)) + } + } + } + } }