Skip to content

Commit 46175d1

Browse files
committed
[SPARK-39321][SQL] Refactor TryCast to use RuntimeReplaceable
### What changes were proposed in this pull request?

This PR refactors `TryCast` to use `RuntimeReplaceable`, so that we don't need `CastBase` anymore. The unit tests are also simplified because we don't need to check the execution of `RuntimeReplaceable`, but only the analysis behavior.

### Why are the changes needed?

Code cleanup.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes apache#36703 from cloud-fan/cast.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent f2f73ed commit 46175d1

File tree

13 files changed

+347
-451
lines changed

13 files changed

+347
-451
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 44 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -425,29 +425,54 @@ object Cast {
425425
}
426426
}
427427

428-
abstract class CastBase extends UnaryExpression
429-
with TimeZoneAwareExpression
430-
with NullIntolerant
431-
with SupportQueryContext {
428+
/**
429+
* Cast the child expression to the target data type.
430+
*
431+
* When cast from/to timezone related types, we need timeZoneId, which will be resolved with
432+
* session local timezone by an analyzer [[ResolveTimeZone]].
433+
*/
434+
@ExpressionDescription(
435+
usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
436+
examples = """
437+
Examples:
438+
> SELECT _FUNC_('10' as int);
439+
10
440+
""",
441+
since = "1.0.0",
442+
group = "conversion_funcs")
443+
case class Cast(
444+
child: Expression,
445+
dataType: DataType,
446+
timeZoneId: Option[String] = None,
447+
ansiEnabled: Boolean = SQLConf.get.ansiEnabled) extends UnaryExpression
448+
with TimeZoneAwareExpression with NullIntolerant with SupportQueryContext {
432449

433-
def child: Expression
450+
def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
451+
this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)
434452

435-
def dataType: DataType
453+
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
454+
copy(timeZoneId = Option(timeZoneId))
436455

437-
/**
438-
* Returns true iff we can cast `from` type to `to` type.
439-
*/
440-
def canCast(from: DataType, to: DataType): Boolean
456+
override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)
441457

442-
/**
443-
* Returns the error message if casting from one type to another one is invalid.
444-
*/
445-
def typeCheckFailureMessage: String
458+
final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)
446459

447-
override def toString: String = s"cast($child as ${dataType.simpleString})"
460+
private def typeCheckFailureMessage: String = if (ansiEnabled) {
461+
if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
462+
Cast.typeCheckFailureMessage(child.dataType, dataType,
463+
Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
464+
} else {
465+
Cast.typeCheckFailureMessage(child.dataType, dataType,
466+
Some(SQLConf.ANSI_ENABLED.key -> "false"))
467+
}
468+
} else {
469+
s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
470+
}
448471

449472
override def checkInputDataTypes(): TypeCheckResult = {
450-
if (canCast(child.dataType, dataType)) {
473+
if (ansiEnabled && Cast.canAnsiCast(child.dataType, dataType)) {
474+
TypeCheckResult.TypeCheckSuccess
475+
} else if (!ansiEnabled && Cast.canCast(child.dataType, dataType)) {
451476
TypeCheckResult.TypeCheckSuccess
452477
} else {
453478
TypeCheckResult.TypeCheckFailure(typeCheckFailureMessage)
@@ -456,8 +481,6 @@ abstract class CastBase extends UnaryExpression
456481

457482
override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType)
458483

459-
protected def ansiEnabled: Boolean
460-
461484
override def initQueryContext(): String = if (ansiEnabled) {
462485
origin.context
463486
} else {
@@ -470,7 +493,7 @@ abstract class CastBase extends UnaryExpression
470493
childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)
471494

472495
override lazy val preCanonicalized: Expression = {
473-
val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[CastBase]
496+
val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[Cast]
474497
if (timeZoneId.isDefined && !needsTimeZone) {
475498
basic.withTimeZone(null)
476499
} else {
@@ -2246,6 +2269,8 @@ abstract class CastBase extends UnaryExpression
22462269
"""
22472270
}
22482271

2272+
override def toString: String = s"cast($child as ${dataType.simpleString})"
2273+
22492274
override def sql: String = dataType match {
22502275
// HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this
22512276
// type of casting can only be introduced by the analyzer, and can be omitted when converting
@@ -2255,57 +2280,6 @@ abstract class CastBase extends UnaryExpression
22552280
}
22562281
}
22572282

2258-
/**
2259-
* Cast the child expression to the target data type.
2260-
*
2261-
* When cast from/to timezone related types, we need timeZoneId, which will be resolved with
2262-
* session local timezone by an analyzer [[ResolveTimeZone]].
2263-
*/
2264-
@ExpressionDescription(
2265-
usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
2266-
examples = """
2267-
Examples:
2268-
> SELECT _FUNC_('10' as int);
2269-
10
2270-
""",
2271-
since = "1.0.0",
2272-
group = "conversion_funcs")
2273-
case class Cast(
2274-
child: Expression,
2275-
dataType: DataType,
2276-
timeZoneId: Option[String] = None,
2277-
override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
2278-
extends CastBase {
2279-
2280-
def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
2281-
this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)
2282-
2283-
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
2284-
copy(timeZoneId = Option(timeZoneId))
2285-
2286-
final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)
2287-
2288-
override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {
2289-
Cast.canAnsiCast(from, to)
2290-
} else {
2291-
Cast.canCast(from, to)
2292-
}
2293-
2294-
override def typeCheckFailureMessage: String = if (ansiEnabled) {
2295-
if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
2296-
Cast.typeCheckFailureMessage(child.dataType, dataType,
2297-
Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
2298-
} else {
2299-
Cast.typeCheckFailureMessage(child.dataType, dataType,
2300-
Some(SQLConf.ANSI_ENABLED.key -> "false"))
2301-
}
2302-
} else {
2303-
s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
2304-
}
2305-
2306-
override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)
2307-
}
2308-
23092283
/**
23102284
* Cast the child expression to the target data type, but will throw error if the cast might
23112285
* truncate, e.g. long -> int, timestamp -> date.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala

Lines changed: 0 additions & 122 deletions
This file was deleted.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
21+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
22+
import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability}
2123
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
2224
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
23-
import org.apache.spark.sql.types.DataType
25+
import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
26+
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
2427

2528
case class TryEval(child: Expression) extends UnaryExpression with NullIntolerant {
2629
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -53,6 +56,87 @@ case class TryEval(child: Expression) extends UnaryExpression with NullIntoleran
5356
copy(child = newChild)
5457
}
5558

59+
/**
60+
* A special version of [[Cast]] with ansi mode on. It performs the same operation (i.e. converts a
61+
* value of one data type into another data type), but returns a NULL value instead of raising an
62+
* error when the conversion can not be performed.
63+
*/
64+
@ExpressionDescription(
65+
usage = """
66+
_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.
67+
This expression is identical to CAST with configuration `spark.sql.ansi.enabled` as
68+
true, except it returns NULL instead of raising an error. Note that the behavior of this
69+
expression doesn't depend on configuration `spark.sql.ansi.enabled`.
70+
""",
71+
examples = """
72+
Examples:
73+
> SELECT _FUNC_('10' as int);
74+
10
75+
> SELECT _FUNC_(1234567890123L as int);
76+
null
77+
""",
78+
since = "3.2.0",
79+
group = "conversion_funcs")
80+
case class TryCast(child: Expression, toType: DataType, timeZoneId: Option[String] = None)
81+
extends UnaryExpression with RuntimeReplaceable with TimeZoneAwareExpression {
82+
83+
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
84+
copy(timeZoneId = Option(timeZoneId))
85+
86+
override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
87+
88+
// When this cast involves TimeZone, it's only resolved if the timeZoneId is set;
89+
// Otherwise behave like Expression.resolved.
90+
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess &&
91+
(!Cast.needsTimeZone(child.dataType, toType) || timeZoneId.isDefined)
92+
93+
override lazy val replacement = {
94+
TryEval(Cast(child, toType, timeZoneId = timeZoneId, ansiEnabled = true))
95+
}
96+
97+
// If the target data type is a complex type which can't have Null values, we should guarantee
98+
// that the casting between the element types won't produce Null results.
99+
private def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
100+
case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
101+
canCast(fromType, toType) &&
102+
resolvableNullability(fn || forceNullable(fromType, toType), tn)
103+
104+
case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
105+
canCast(fromKey, toKey) &&
106+
(!forceNullable(fromKey, toKey)) &&
107+
canCast(fromValue, toValue) &&
108+
resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
109+
110+
case (StructType(fromFields), StructType(toFields)) =>
111+
fromFields.length == toFields.length &&
112+
fromFields.zip(toFields).forall {
113+
case (fromField, toField) =>
114+
canCast(fromField.dataType, toField.dataType) &&
115+
resolvableNullability(
116+
fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
117+
toField.nullable)
118+
}
119+
120+
case _ =>
121+
Cast.canAnsiCast(from, to)
122+
}
123+
124+
override def checkInputDataTypes(): TypeCheckResult = {
125+
if (canCast(child.dataType, dataType)) {
126+
TypeCheckResult.TypeCheckSuccess
127+
} else {
128+
TypeCheckResult.TypeCheckFailure(Cast.typeCheckFailureMessage(child.dataType, toType, None))
129+
}
130+
}
131+
132+
override def toString: String = s"try_cast($child as ${dataType.simpleString})"
133+
134+
override def sql: String = s"TRY_CAST(${child.sql} AS ${dataType.sql})"
135+
136+
override protected def withNewChildInternal(newChild: Expression): Expression =
137+
this.copy(child = newChild)
138+
}
139+
56140
// scalastyle:off line.size.limit
57141
@ExpressionDescription(
58142
usage = "_FUNC_(expr1, expr2) - Returns the sum of `expr1`and `expr2` and the result is null on overflow. " +

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,8 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
631631
case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true
632632
case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length =>
633633
true
634-
case _: CastBase => true
634+
case _: Cast => true
635+
case _: TryEval => true
635636
case _: GetDateField | _: LastDay => true
636637
case _: ExtractIntervalPart[_] => true
637638
case _: ArraySetLike => true

0 commit comments

Comments (0)