Skip to content

Commit aa7790e

Browse files
committed
Moved RemoveNullTypes into ImplicitTypeCasts.
1 parent 438ea07 commit aa7790e

File tree

2 files changed

+40
-67
lines changed

2 files changed

+40
-67
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 34 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ object HiveTypeCoercion {
3636
val typeCoercionRules =
3737
PropagateTypes ::
3838
InConversion ::
39-
RemoveNullTypes ::
4039
WidenTypes ::
4140
PromoteStrings ::
4241
DecimalPrecision ::
@@ -148,47 +147,6 @@ object HiveTypeCoercion {
148147
}
149148
}
150149

151-
/**
152-
* Removes [[NullType]] (from null literals in SQL) from expressions by adding an explicit cast
153-
* into the type the expression supports. This rule is here to avoid handling [[NullType]] in
154-
* every expression implementations.
155-
*
156-
* This works by looking up the expected input types for expressions, and cast [[NullType]]
157-
* into some other specific data type that the expression expects. For example, consider [[Add]],
158-
* an expression that supports any numeric types.
159-
*
160-
* When applying this rule on `Add(NullType, NullType)`, the expression will be converted to
161-
* `Add(DoubleType, DoubleType)`.
162-
*/
163-
object RemoveNullTypes extends Rule[LogicalPlan] {
164-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
165-
case q: LogicalPlan => q transformExpressions {
166-
// Skip nodes who's children have not been resolved yet.
167-
case e if !e.childrenResolved => e
168-
169-
// For binary operators whose input types are NullType, cast them to some specific type.
170-
case b @ BinaryOperator(left, right)
171-
if left.dataType == NullType && right.dataType == NullType &&
172-
!b.inputType.acceptsType(NullType) =>
173-
// If both inputs are null type (from null literals), cast the null type into some
174-
// specific type the expression expects, so expressions don't need to handle NullType
175-
val newLeft = Cast(left, b.inputType.defaultConcreteType)
176-
val newRight = Cast(right, b.inputType.defaultConcreteType)
177-
b.makeCopy(Array(newLeft, newRight))
178-
179-
case e: ExpectsInputTypes if e.inputTypes.nonEmpty =>
180-
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
181-
if (in.dataType == NullType && !expected.acceptsType(NullType)) {
182-
Cast(in, expected.defaultConcreteType)
183-
} else {
184-
in
185-
}
186-
}
187-
e.withNewChildren(children)
188-
}
189-
}
190-
}
191-
192150
/**
193151
* Widens numeric types and converts strings to numbers when appropriate.
194152
*
@@ -256,25 +214,6 @@ object HiveTypeCoercion {
256214
}
257215

258216
Union(newLeft, newRight)
259-
260-
// Also widen types for BinaryOperator.
261-
case q: LogicalPlan => q transformExpressions {
262-
// Skip nodes who's children have not been resolved yet.
263-
case e if !e.childrenResolved => e
264-
265-
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
266-
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
267-
if (b.inputType.acceptsType(commonType)) {
268-
// If the expression accepts the tighest common type, cast to that.
269-
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
270-
val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
271-
b.makeCopy(Array(newLeft, newRight))
272-
} else {
273-
// Otherwise, don't do anything with the expression.
274-
b
275-
}
276-
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
277-
}
278217
}
279218
}
280219

@@ -734,6 +673,40 @@ object HiveTypeCoercion {
734673
implicitCast(in, expected).getOrElse(in)
735674
}
736675
e.withNewChildren(children)
676+
677+
case e: ExpectsInputTypes if e.inputTypes.nonEmpty =>
678+
// Convert NullType into some specific target type for ExpectsInputTypes that don't do
679+
// general implicit casting.
680+
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
681+
if (in.dataType == NullType && !expected.acceptsType(NullType)) {
682+
Cast(in, expected.defaultConcreteType)
683+
} else {
684+
in
685+
}
686+
}
687+
e.withNewChildren(children)
688+
689+
case b @ BinaryOperator(left, right)
690+
if left.dataType == NullType && right.dataType == NullType &&
691+
!b.inputType.acceptsType(NullType) =>
692+
// If both inputs are null type (from null literals), cast the null type into some
693+
// specific type the expression expects, so expressions don't need to handle NullType
694+
val newLeft = Cast(left, b.inputType.defaultConcreteType)
695+
val newRight = Cast(right, b.inputType.defaultConcreteType)
696+
b.makeCopy(Array(newLeft, newRight))
697+
698+
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
699+
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
700+
if (b.inputType.acceptsType(commonType)) {
701+
// If the expression accepts the tighest common type, cast to that.
702+
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
703+
val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
704+
b.makeCopy(Array(newLeft, newRight))
705+
} else {
706+
// Otherwise, don't do anything with the expression.
707+
b
708+
}
709+
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
737710
}
738711

739712
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,26 +194,26 @@ class HiveTypeCoercionSuite extends PlanTest {
194194
Project(Seq(Alias(transformed, "a")()), testRelation))
195195
}
196196

197-
test("RemoveNullTypes for expressions that define ExpectsInputTypes") {
197+
test("cast NullType for expresions that implement ExpectsInputTypes") {
198198
import HiveTypeCoercionSuite._
199199

200-
ruleTest(HiveTypeCoercion.RemoveNullTypes,
200+
ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
201201
AnyTypeUnaryExpression(Literal.create(null, NullType)),
202202
AnyTypeUnaryExpression(Literal.create(null, NullType)))
203203

204-
ruleTest(HiveTypeCoercion.RemoveNullTypes,
204+
ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
205205
NumericTypeUnaryExpression(Literal.create(null, NullType)),
206206
NumericTypeUnaryExpression(Cast(Literal.create(null, NullType), DoubleType)))
207207
}
208208

209-
test("RemoveNullTypes for binary operators") {
209+
test("cast NullType for binary operators") {
210210
import HiveTypeCoercionSuite._
211211

212-
ruleTest(HiveTypeCoercion.RemoveNullTypes,
212+
ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
213213
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
214214
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
215215

216-
ruleTest(HiveTypeCoercion.RemoveNullTypes,
216+
ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
217217
NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
218218
NumericTypeBinaryOperator(
219219
Cast(Literal.create(null, NullType), DoubleType),

0 commit comments

Comments
 (0)