Commit f8d51b9
[SPARK-40054][SQL] Restore the error handling syntax of try_cast()
### What changes were proposed in this pull request?

For the following query:

```sql
SET spark.sql.ansi.enabled=true;
SELECT try_cast(1/0 AS string);
```

Spark 3.3 throws an exception for the division-by-zero error. In the current master branch, it returns null after the refactoring PR #36703. This PR restores the previous error handling behavior.

### Why are the changes needed?

1. Restore the behavior of try_cast().
2. The previous behavior is more reasonable: users can cleanly catch the exception from the child of `try_cast`.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests

Closes #37486 from gengliangwang/restoreTryCast.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
1 parent 1a738e7 commit f8d51b9
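
A minimal reproduction of the restored behavior, as a spark-shell sketch (assumes a shell-provided `SparkSession` named `spark`; the error class in the comment is illustrative):

```scala
// spark-shell sketch; `spark` is the session provided by the shell.
spark.conf.set("spark.sql.ansi.enabled", "true")

// The division-by-zero error is raised by the *child* of try_cast, so it
// propagates again (as in Spark 3.3) instead of being swallowed:
spark.sql("SELECT try_cast(1/0 AS string)").collect()  // throws DIVIDE_BY_ZERO

// Errors raised by the cast itself are still converted to null:
spark.sql("SELECT try_cast('abc' AS int)").collect()   // Array([null])
```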

12 files changed (+326, -194 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 121 additions & 20 deletions
```diff
@@ -146,6 +146,33 @@ object Cast {
     case _ => false
   }
 
+  // If the target data type is a complex type which can't have Null values, we should guarantee
+  // that the casting between the element types won't produce Null results.
+  def canTryCast(from: DataType, to: DataType): Boolean = (from, to) match {
+    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+      canCast(fromType, toType) &&
+        resolvableNullability(fn || forceNullable(fromType, toType), tn)
+
+    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+      canCast(fromKey, toKey) &&
+        (!forceNullable(fromKey, toKey)) &&
+        canCast(fromValue, toValue) &&
+        resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
+
+    case (StructType(fromFields), StructType(toFields)) =>
+      fromFields.length == toFields.length &&
+        fromFields.zip(toFields).forall {
+          case (fromField, toField) =>
+            canCast(fromField.dataType, toField.dataType) &&
+              resolvableNullability(
+                fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
+                toField.nullable)
+        }
+
+    case _ =>
+      Cast.canAnsiCast(from, to)
+  }
+
   /**
    * A tag to identify if a CAST added by the table insertion resolver.
    */
@@ -426,6 +453,19 @@ object Cast {
 
     case _ => s"cannot cast ${from.catalogString} to ${to.catalogString}"
   }
+
+  def apply(
+      child: Expression,
+      dataType: DataType,
+      ansiEnabled: Boolean): Cast =
+    Cast(child, dataType, None, EvalMode.fromBoolean(ansiEnabled))
+
+  def apply(
+      child: Expression,
+      dataType: DataType,
+      timeZoneId: Option[String],
+      ansiEnabled: Boolean): Cast =
+    Cast(child, dataType, timeZoneId, EvalMode.fromBoolean(ansiEnabled))
 }
 
 /**
@@ -447,11 +487,11 @@ case class Cast(
     child: Expression,
     dataType: DataType,
     timeZoneId: Option[String] = None,
-    ansiEnabled: Boolean = SQLConf.get.ansiEnabled) extends UnaryExpression
+    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends UnaryExpression
   with TimeZoneAwareExpression with NullIntolerant with SupportQueryContext {
 
   def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
-    this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)
+    this(child, dataType, timeZoneId, evalMode = EvalMode.fromSQLConf(SQLConf.get))
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))
@@ -460,29 +500,57 @@ case class Cast(
 
   final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)
 
-  private def typeCheckFailureMessage: String = if (ansiEnabled) {
-    if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
-      Cast.typeCheckFailureMessage(child.dataType, dataType,
-        Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
-    } else {
-      Cast.typeCheckFailureMessage(child.dataType, dataType,
-        Some(SQLConf.ANSI_ENABLED.key -> "false"))
-    }
-  } else {
-    s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
+  def ansiEnabled: Boolean = {
+    evalMode == EvalMode.ANSI || evalMode == EvalMode.TRY
+  }
+
+  // Whether this expression is used for `try_cast()`.
+  def isTryCast: Boolean = {
+    evalMode == EvalMode.TRY
+  }
+
+  private def typeCheckFailureMessage: String = evalMode match {
+    case EvalMode.ANSI =>
+      if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
+        Cast.typeCheckFailureMessage(child.dataType, dataType,
+          Some(SQLConf.STORE_ASSIGNMENT_POLICY.key ->
+            SQLConf.StoreAssignmentPolicy.LEGACY.toString))
+      } else {
+        Cast.typeCheckFailureMessage(child.dataType, dataType,
+          Some(SQLConf.ANSI_ENABLED.key -> "false"))
+      }
+    case EvalMode.TRY =>
+      Cast.typeCheckFailureMessage(child.dataType, dataType, None)
+    case _ =>
+      s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (ansiEnabled && Cast.canAnsiCast(child.dataType, dataType)) {
-      TypeCheckResult.TypeCheckSuccess
-    } else if (!ansiEnabled && Cast.canCast(child.dataType, dataType)) {
+    val canCast = evalMode match {
+      case EvalMode.LEGACY => Cast.canCast(child.dataType, dataType)
+      case EvalMode.ANSI => Cast.canAnsiCast(child.dataType, dataType)
+      case EvalMode.TRY => Cast.canTryCast(child.dataType, dataType)
+      case other => throw new IllegalArgumentException(s"Unknown EvalMode value: $other")
+    }
+    if (canCast) {
       TypeCheckResult.TypeCheckSuccess
     } else {
       TypeCheckResult.TypeCheckFailure(typeCheckFailureMessage)
     }
   }
 
-  override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType)
+  override def nullable: Boolean = if (!isTryCast) {
+    child.nullable || Cast.forceNullable(child.dataType, dataType)
+  } else {
+    (child.dataType, dataType) match {
+      case (StringType, BinaryType) => child.nullable
+      // TODO: Implement a more accurate method for checking whether a decimal value can be cast
+      //       as integral types without overflow. Currently, the cast can overflow even if
+      //       "Cast.canUpCast" method returns true.
+      case (_: DecimalType, _: IntegralType) => true
+      case _ => child.nullable || !Cast.canUpCast(child.dataType, dataType)
+    }
+  }
 
   override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) {
     Some(origin.context)
@@ -1146,7 +1214,7 @@ case class Cast(
     })
   }
 
-  protected[this] def cast(from: DataType, to: DataType): Any => Any = {
+  private def castInternal(from: DataType, to: DataType): Any => Any = {
     // If the cast does not change the structure, then we don't really need to cast anything.
     // We can return what the children return. Same thing should happen in the codegen path.
     if (DataType.equalsStructurally(from, to)) {
@@ -1188,6 +1256,20 @@ case class Cast(
     }
   }
 
+  private def cast(from: DataType, to: DataType): Any => Any = {
+    if (!isTryCast) {
+      castInternal(from, to)
+    } else {
+      (input: Any) =>
+        try {
+          castInternal(from, to)(input)
+        } catch {
+          case _: Exception =>
+            null
+        }
+    }
+  }
+
   protected[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
 
   protected override def nullSafeEval(input: Any): Any = cast(input)
@@ -1253,11 +1335,22 @@ case class Cast(
   protected[this] def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue,
     result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = {
     val javaType = JavaCode.javaType(resultType)
+    val castCodeWithTryCatchIfNeeded = if (!isTryCast) {
+      s"${cast(input, result, resultIsNull)}"
+    } else {
+      s"""
+         |try {
+         |  ${cast(input, result, resultIsNull)}
+         |} catch (Exception e) {
+         |  $resultIsNull = true;
+         |}
+         |""".stripMargin
+    }
     code"""
       boolean $resultIsNull = $inputIsNull;
       $javaType $result = ${CodeGenerator.defaultValue(resultType)};
       if (!$inputIsNull) {
-        ${cast(input, result, resultIsNull)}
+        $castCodeWithTryCatchIfNeeded
      }
    """
  }
@@ -2345,14 +2438,22 @@ case class Cast(
     """
   }
 
-  override def toString: String = s"cast($child as ${dataType.simpleString})"
+  override def prettyName: String = if (!isTryCast) {
+    "cast"
+  } else {
+    "try_cast"
+  }
+
+  override def toString: String = {
+    s"$prettyName($child as ${dataType.simpleString})"
+  }
 
   override def sql: String = dataType match {
     // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this
     // type of casting can only be introduced by the analyzer, and can be omitted when converting
     // back to SQL query string.
     case _: ArrayType | _: MapType | _: StructType => child.sql
-    case _ => s"CAST(${child.sql} AS ${dataType.sql})"
+    case _ => s"${prettyName.toUpperCase(Locale.ROOT)}(${child.sql} AS ${dataType.sql})"
   }
 }
```
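The effect of the new `evalMode` parameter on the interpreted path can be sketched with catalyst internals (illustrative only; the overflow example mirrors the removed `TryCast` documentation, and `Cast`, `EvalMode`, and `Literal` are the classes from this diff):

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, EvalMode, Literal}
import org.apache.spark.sql.types.IntegerType

// TRY mode wraps castInternal in a try/catch, so the long-to-int
// overflow becomes null instead of an exception:
val tryMode = Cast(Literal(1234567890123L), IntegerType, None, EvalMode.TRY)
assert(tryMode.eval() == null)

// The same cast in ANSI mode raises the overflow error:
val ansiMode = Cast(Literal(1234567890123L), IntegerType, None, EvalMode.ANSI)
// ansiMode.eval() throws an ArithmeticException (CAST_OVERFLOW).
```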

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala

Lines changed: 43 additions & 0 deletions
```diff
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Expression evaluation modes.
+ *   - LEGACY: the default evaluation mode, which is compliant with Hive SQL.
+ *   - ANSI: an evaluation mode which is compliant with the ANSI SQL standard.
+ *   - TRY: an evaluation mode for `try_*` functions. It is identical to ANSI evaluation mode
+ *          except for returning null results on errors.
+ */
+
+object EvalMode extends Enumeration {
+  val LEGACY, ANSI, TRY = Value
+
+  def fromSQLConf(conf: SQLConf): Value = if (conf.ansiEnabled) {
+    ANSI
+  } else {
+    LEGACY
+  }
+
+  def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) {
+    ANSI
+  } else {
+    LEGACY
+  }
+}
```
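A quick sketch of how the two factory methods are meant to be used:

```scala
import org.apache.spark.sql.catalyst.expressions.EvalMode

assert(EvalMode.fromBoolean(true) == EvalMode.ANSI)
assert(EvalMode.fromBoolean(false) == EvalMode.LEGACY)
// fromSQLConf resolves to ANSI or LEGACY from spark.sql.ansi.enabled;
// TRY is never derived from configuration, the parser sets it
// explicitly for try_cast (see the AstBuilder change below).
```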

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala

Lines changed: 1 addition & 85 deletions
```diff
@@ -18,12 +18,9 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+import org.apache.spark.sql.types.DataType
 
 case class TryEval(child: Expression) extends UnaryExpression with NullIntolerant {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -56,87 +53,6 @@ case class TryEval(child: Expression) extends UnaryExpression with NullIntolerant
     copy(child = newChild)
 }
 
-/**
- * A special version of [[Cast]] with ansi mode on. It performs the same operation (i.e. converts a
- * value of one data type into another data type), but returns a NULL value instead of raising an
- * error when the conversion can not be performed.
- */
-@ExpressionDescription(
-  usage = """
-    _FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.
-      This expression is identical to CAST with configuration `spark.sql.ansi.enabled` as
-      true, except it returns NULL instead of raising an error. Note that the behavior of this
-      expression doesn't depend on configuration `spark.sql.ansi.enabled`.
-  """,
-  examples = """
-    Examples:
-      > SELECT _FUNC_('10' as int);
-       10
-      > SELECT _FUNC_(1234567890123L as int);
-       null
-  """,
-  since = "3.2.0",
-  group = "conversion_funcs")
-case class TryCast(child: Expression, toType: DataType, timeZoneId: Option[String] = None)
-  extends UnaryExpression with RuntimeReplaceable with TimeZoneAwareExpression {
-
-  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
-    copy(timeZoneId = Option(timeZoneId))
-
-  override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
-
-  // When this cast involves TimeZone, it's only resolved if the timeZoneId is set;
-  // Otherwise behave like Expression.resolved.
-  override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess &&
-    (!Cast.needsTimeZone(child.dataType, toType) || timeZoneId.isDefined)
-
-  override lazy val replacement = {
-    TryEval(Cast(child, toType, timeZoneId = timeZoneId, ansiEnabled = true))
-  }
-
-  // If the target data type is a complex type which can't have Null values, we should guarantee
-  // that the casting between the element types won't produce Null results.
-  private def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
-    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
-      canCast(fromType, toType) &&
-        resolvableNullability(fn || forceNullable(fromType, toType), tn)
-
-    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
-      canCast(fromKey, toKey) &&
-        (!forceNullable(fromKey, toKey)) &&
-        canCast(fromValue, toValue) &&
-        resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
-
-    case (StructType(fromFields), StructType(toFields)) =>
-      fromFields.length == toFields.length &&
-        fromFields.zip(toFields).forall {
-          case (fromField, toField) =>
-            canCast(fromField.dataType, toField.dataType) &&
-              resolvableNullability(
-                fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
-                toField.nullable)
-        }
-
-    case _ =>
-      Cast.canAnsiCast(from, to)
-  }
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (canCast(child.dataType, dataType)) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(Cast.typeCheckFailureMessage(child.dataType, toType, None))
-    }
-  }
-
-  override def toString: String = s"try_cast($child as ${dataType.simpleString})"
-
-  override def sql: String = s"TRY_CAST(${child.sql} AS ${dataType.sql})"
-
-  override protected def withNewChildInternal(newChild: Expression): Expression =
-    this.copy(child = newChild)
-}
-
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = "_FUNC_(expr1, expr2) - Returns the sum of `expr1`and `expr2` and the result is null on overflow. " +
```
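The removed `TryCast` was `RuntimeReplaceable`, rewritten to `TryEval(Cast(..., ansiEnabled = true))`, and `TryEval` catches errors thrown anywhere below it, including by the cast's child. A hedged sketch of the semantic difference (catalyst internals, illustrative only):

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, EvalMode, Literal, TryEval}
import org.apache.spark.sql.types.{IntegerType, StringType}

// A child expression that itself fails under ANSI evaluation:
val failingChild = Cast(Literal("abc"), IntegerType, None, EvalMode.ANSI)

// Old try_cast: TryEval swallows the child's error too and returns null.
val old = TryEval(Cast(failingChild, StringType, None, EvalMode.ANSI))
assert(old.eval() == null)

// Restored try_cast: only the outer cast's own error is caught; the
// child's number-format error propagates, as it did in Spark 3.3.
val restored = Cast(failingChild, StringType, None, EvalMode.TRY)
// restored.eval() throws (e.g. a SparkNumberFormatException)
```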

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -1795,9 +1795,9 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper with Logging {
         cast
 
       case SqlBaseParser.TRY_CAST =>
-        // `TryCast` can only be user-specified and we don't need to set the USER_SPECIFIED_CAST
-        // tag, which is only used by `Cast`
-        TryCast(expression(ctx.expression), dataType)
+        val cast = Cast(expression(ctx.expression), dataType, evalMode = EvalMode.TRY)
+        cast.setTagValue(Cast.USER_SPECIFIED_CAST, true)
+        cast
     }
   }
```
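With this parser change, TRY_CAST builds a `Cast` in TRY mode directly; a quick sketch using the catalyst parser (illustrative; `CatalystSqlParser` is the existing parser entry point):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

val e = CatalystSqlParser.parseExpression("try_cast('10' AS int)")
// e is a Cast with evalMode = EvalMode.TRY rather than a separate TryCast node;
// e.prettyName == "try_cast" and e.sql == "TRY_CAST('10' AS INT)"
```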
