Skip to content

Commit d55c9e5

Browse files
committed
Removes NullTypes.
1 parent 360d124 commit d55c9e5

File tree

2 files changed

+94
-12
lines changed

2 files changed

+94
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ object HiveTypeCoercion {
3636
val typeCoercionRules =
3737
PropagateTypes ::
3838
InConversion ::
39+
RemoveNullTypes ::
3940
WidenTypes ::
4041
PromoteStrings ::
4142
DecimalPrecision ::
@@ -147,6 +148,47 @@ object HiveTypeCoercion {
147148
}
148149
}
149150

151+
/**
152+
* Removes [[NullType]] (from null literals in SQL) from expressions by adding an explicit cast
153+
* into the type the expression supports. This rule is here to avoid handling [[NullType]] in
154+
* every expression implementations.
155+
*
156+
* This works by looking up the expected input types for expressions, and cast [[NullType]]
157+
* into some other specific data type that the expression expects. For example, consider [[Add]],
158+
* an expression that supports any numeric types.
159+
*
160+
* When applying this rule on `Add(NullType, NullType)`, the expression will be converted to
161+
* `Add(DoubleType, DoubleType)`.
162+
*/
163+
object RemoveNullTypes extends Rule[LogicalPlan] {
164+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
165+
case q: LogicalPlan => q transformExpressions {
166+
// Skip nodes who's children have not been resolved yet.
167+
case e if !e.childrenResolved => e
168+
169+
// For binary operators whose input types are NullType, cast them to some specific type.
170+
case b@BinaryOperator(left, right)
171+
if left.dataType == NullType && right.dataType == NullType &&
172+
!b.inputType.acceptsType(NullType) =>
173+
// If both inputs are null type (from null literals), cast the null type into some
174+
// specific type the expression expects, so expressions don't need to handle NullType
175+
val newLeft = Cast(left, b.inputType.defaultConcreteType)
176+
val newRight = Cast(right, b.inputType.defaultConcreteType)
177+
b.makeCopy(Array(newLeft, newRight))
178+
179+
case e: ExpectsInputTypes if e.inputTypes.nonEmpty =>
180+
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
181+
if (in.dataType == NullType && !expected.acceptsType(NullType)) {
182+
Cast(in, expected.defaultConcreteType)
183+
} else {
184+
in
185+
}
186+
}
187+
e.withNewChildren(children)
188+
}
189+
}
190+
}
191+
150192
/**
151193
* Widens numeric types and converts strings to numbers when appropriate.
152194
*
@@ -220,14 +262,6 @@ object HiveTypeCoercion {
220262
// Skip nodes who's children have not been resolved yet.
221263
case e if !e.childrenResolved => e
222264

223-
case b @ BinaryOperator(left, right)
224-
if left.dataType == NullType && right.dataType == NullType =>
225-
// If both inputs are null type (from null literals), cast the null type into some
226-
// specific type the expression expects, so expressions don't need to handle NullType
227-
val newLeft = Cast(left, b.inputType.defaultConcreteType)
228-
val newRight = Cast(right, b.inputType.defaultConcreteType)
229-
b.makeCopy(Array(newLeft, newRight))
230-
231265
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
232266
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
233267
if (b.inputType.acceptsType(commonType)) {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,28 @@ class HiveTypeCoercionSuite extends PlanTest {
194194
Project(Seq(Alias(transformed, "a")()), testRelation))
195195
}
196196

197-
test("null literals handling for binary operators") {
198-
ruleTest(HiveTypeCoercion.WidenTypes,
199-
Add(Literal.create(null, NullType), Literal.create(null, NullType)),
200-
Add(
197+
test("RemoveNullTypes for expressions that define ExpectsInputTypes") {
198+
import HiveTypeCoercionSuite._
199+
200+
ruleTest(HiveTypeCoercion.RemoveNullTypes,
201+
AnyTypeUnaryExpression(Literal.create(null, NullType)),
202+
AnyTypeUnaryExpression(Literal.create(null, NullType)))
203+
204+
ruleTest(HiveTypeCoercion.RemoveNullTypes,
205+
NumericTypeUnaryExpression(Literal.create(null, NullType)),
206+
NumericTypeUnaryExpression(Cast(Literal.create(null, NullType), DoubleType)))
207+
}
208+
209+
test("RemoveNullTypes for binary operators") {
210+
import HiveTypeCoercionSuite._
211+
212+
ruleTest(HiveTypeCoercion.RemoveNullTypes,
213+
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
214+
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
215+
216+
ruleTest(HiveTypeCoercion.RemoveNullTypes,
217+
NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
218+
NumericTypeBinaryOperator(
201219
Cast(Literal.create(null, NullType), DoubleType),
202220
Cast(Literal.create(null, NullType), DoubleType)))
203221
}
@@ -310,3 +328,33 @@ class HiveTypeCoercionSuite extends PlanTest {
310328
)
311329
}
312330
}
331+
332+
333+
object HiveTypeCoercionSuite {
334+
335+
case class AnyTypeUnaryExpression(child: Expression)
336+
extends UnaryExpression with ExpectsInputTypes {
337+
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
338+
override def dataType: DataType = NullType
339+
}
340+
341+
case class NumericTypeUnaryExpression(child: Expression)
342+
extends UnaryExpression with ExpectsInputTypes {
343+
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
344+
override def dataType: DataType = NullType
345+
}
346+
347+
case class AnyTypeBinaryOperator(left: Expression, right: Expression)
348+
extends BinaryOperator with ExpectsInputTypes {
349+
override def dataType: DataType = NullType
350+
override def inputType: AbstractDataType = AnyDataType
351+
override def symbol: String = "anytype"
352+
}
353+
354+
case class NumericTypeBinaryOperator(left: Expression, right: Expression)
355+
extends BinaryOperator with ExpectsInputTypes {
356+
override def dataType: DataType = NullType
357+
override def inputType: AbstractDataType = NumericType
358+
override def symbol: String = "numerictype"
359+
}
360+
}

0 commit comments

Comments
 (0)