@@ -146,6 +146,33 @@ object Cast {
146146 case _ => false
147147 }
148148
149+ // If the target data type is a complex type which can't have Null values, we should guarantee
150+ // that the casting between the element types won't produce Null results.
151+ def canTryCast (from : DataType , to : DataType ): Boolean = (from, to) match {
152+ case (ArrayType (fromType, fn), ArrayType (toType, tn)) =>
153+ canCast(fromType, toType) &&
154+ resolvableNullability(fn || forceNullable(fromType, toType), tn)
155+
156+ case (MapType (fromKey, fromValue, fn), MapType (toKey, toValue, tn)) =>
157+ canCast(fromKey, toKey) &&
158+ (! forceNullable(fromKey, toKey)) &&
159+ canCast(fromValue, toValue) &&
160+ resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
161+
162+ case (StructType (fromFields), StructType (toFields)) =>
163+ fromFields.length == toFields.length &&
164+ fromFields.zip(toFields).forall {
165+ case (fromField, toField) =>
166+ canCast(fromField.dataType, toField.dataType) &&
167+ resolvableNullability(
168+ fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
169+ toField.nullable)
170+ }
171+
172+ case _ =>
173+ Cast .canAnsiCast(from, to)
174+ }
175+
149176 /**
150177 * A tag to identify if a CAST added by the table insertion resolver.
151178 */
@@ -426,6 +453,19 @@ object Cast {
426453
427454 case _ => s " cannot cast ${from.catalogString} to ${to.catalogString}"
428455 }
456+
457+ def apply (
458+ child : Expression ,
459+ dataType : DataType ,
460+ ansiEnabled : Boolean ): Cast =
461+ Cast (child, dataType, None , EvalMode .fromBoolean(ansiEnabled))
462+
463+ def apply (
464+ child : Expression ,
465+ dataType : DataType ,
466+ timeZoneId : Option [String ],
467+ ansiEnabled : Boolean ): Cast =
468+ Cast (child, dataType, timeZoneId, EvalMode .fromBoolean(ansiEnabled))
429469}
430470
431471/**
@@ -447,11 +487,11 @@ case class Cast(
447487 child : Expression ,
448488 dataType : DataType ,
449489 timeZoneId : Option [String ] = None ,
450- ansiEnabled : Boolean = SQLConf .get.ansiEnabled ) extends UnaryExpression
490+ evalMode : EvalMode . Value = EvalMode .fromSQLConf( SQLConf .get) ) extends UnaryExpression
451491 with TimeZoneAwareExpression with NullIntolerant with SupportQueryContext {
452492
453493 def this (child : Expression , dataType : DataType , timeZoneId : Option [String ]) =
454- this (child, dataType, timeZoneId, ansiEnabled = SQLConf .get.ansiEnabled )
494+ this (child, dataType, timeZoneId, evalMode = EvalMode .fromSQLConf( SQLConf .get) )
455495
456496 override def withTimeZone (timeZoneId : String ): TimeZoneAwareExpression =
457497 copy(timeZoneId = Option (timeZoneId))
@@ -460,29 +500,57 @@ case class Cast(
460500
461501 final override def nodePatternsInternal (): Seq [TreePattern ] = Seq (CAST )
462502
463- private def typeCheckFailureMessage : String = if (ansiEnabled) {
464- if (getTagValue(Cast .BY_TABLE_INSERTION ).isDefined) {
465- Cast .typeCheckFailureMessage(child.dataType, dataType,
466- Some (SQLConf .STORE_ASSIGNMENT_POLICY .key -> SQLConf .StoreAssignmentPolicy .LEGACY .toString))
467- } else {
468- Cast .typeCheckFailureMessage(child.dataType, dataType,
469- Some (SQLConf .ANSI_ENABLED .key -> " false" ))
470- }
471- } else {
472- s " cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
503+ def ansiEnabled : Boolean = {
504+ evalMode == EvalMode .ANSI || evalMode == EvalMode .TRY
505+ }
506+
507+ // Whether this expression is used for `try_cast()`.
508+ def isTryCast : Boolean = {
509+ evalMode == EvalMode .TRY
510+ }
511+
512+ private def typeCheckFailureMessage : String = evalMode match {
513+ case EvalMode .ANSI =>
514+ if (getTagValue(Cast .BY_TABLE_INSERTION ).isDefined) {
515+ Cast .typeCheckFailureMessage(child.dataType, dataType,
516+ Some (SQLConf .STORE_ASSIGNMENT_POLICY .key ->
517+ SQLConf .StoreAssignmentPolicy .LEGACY .toString))
518+ } else {
519+ Cast .typeCheckFailureMessage(child.dataType, dataType,
520+ Some (SQLConf .ANSI_ENABLED .key -> " false" ))
521+ }
522+ case EvalMode .TRY =>
523+ Cast .typeCheckFailureMessage(child.dataType, dataType, None )
524+ case _ =>
525+ s " cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
473526 }
474527
475528 override def checkInputDataTypes (): TypeCheckResult = {
476- if (ansiEnabled && Cast .canAnsiCast(child.dataType, dataType)) {
477- TypeCheckResult .TypeCheckSuccess
478- } else if (! ansiEnabled && Cast .canCast(child.dataType, dataType)) {
529+ val canCast = evalMode match {
530+ case EvalMode .LEGACY => Cast .canCast(child.dataType, dataType)
531+ case EvalMode .ANSI => Cast .canAnsiCast(child.dataType, dataType)
532+ case EvalMode .TRY => Cast .canTryCast(child.dataType, dataType)
533+ case other => throw new IllegalArgumentException (s " Unknown EvalMode value: $other" )
534+ }
535+ if (canCast) {
479536 TypeCheckResult .TypeCheckSuccess
480537 } else {
481538 TypeCheckResult .TypeCheckFailure (typeCheckFailureMessage)
482539 }
483540 }
484541
485- override def nullable : Boolean = child.nullable || Cast .forceNullable(child.dataType, dataType)
542+ override def nullable : Boolean = if (! isTryCast) {
543+ child.nullable || Cast .forceNullable(child.dataType, dataType)
544+ } else {
545+ (child.dataType, dataType) match {
546+ case (StringType , BinaryType ) => child.nullable
547+ // TODO: Implement a more accurate method for checking whether a decimal value can be cast
548+ // as integral types without overflow. Currently, the cast can overflow even if
549+ // "Cast.canUpCast" method returns true.
550+ case (_ : DecimalType , _ : IntegralType ) => true
551+ case _ => child.nullable || ! Cast .canUpCast(child.dataType, dataType)
552+ }
553+ }
486554
487555 override def initQueryContext (): Option [SQLQueryContext ] = if (ansiEnabled) {
488556 Some (origin.context)
@@ -1146,7 +1214,7 @@ case class Cast(
11461214 })
11471215 }
11481216
1149- protected [ this ] def cast (from : DataType , to : DataType ): Any => Any = {
1217+ private def castInternal (from : DataType , to : DataType ): Any => Any = {
11501218 // If the cast does not change the structure, then we don't really need to cast anything.
11511219 // We can return what the children return. Same thing should happen in the codegen path.
11521220 if (DataType .equalsStructurally(from, to)) {
@@ -1188,6 +1256,20 @@ case class Cast(
11881256 }
11891257 }
11901258
1259+ private def cast (from : DataType , to : DataType ): Any => Any = {
1260+ if (! isTryCast) {
1261+ castInternal(from, to)
1262+ } else {
1263+ (input : Any ) =>
1264+ try {
1265+ castInternal(from, to)(input)
1266+ } catch {
1267+ case _ : Exception =>
1268+ null
1269+ }
1270+ }
1271+ }
1272+
11911273 protected [this ] lazy val cast : Any => Any = cast(child.dataType, dataType)
11921274
11931275 protected override def nullSafeEval (input : Any ): Any = cast(input)
@@ -1253,11 +1335,22 @@ case class Cast(
12531335 protected [this ] def castCode (ctx : CodegenContext , input : ExprValue , inputIsNull : ExprValue ,
12541336 result : ExprValue , resultIsNull : ExprValue , resultType : DataType , cast : CastFunction ): Block = {
12551337 val javaType = JavaCode .javaType(resultType)
1338+ val castCodeWithTryCatchIfNeeded = if (! isTryCast) {
1339+ s " ${cast(input, result, resultIsNull)}"
1340+ } else {
1341+ s """
1342+ |try {
1343+ | ${cast(input, result, resultIsNull)}
1344+ |} catch (Exception e) {
1345+ | $resultIsNull = true;
1346+ |}
1347+ | """ .stripMargin
1348+ }
12561349 code """
12571350 boolean $resultIsNull = $inputIsNull;
12581351 $javaType $result = ${CodeGenerator .defaultValue(resultType)};
12591352 if (! $inputIsNull) {
1260- ${cast(input, result, resultIsNull)}
1353+ $castCodeWithTryCatchIfNeeded
12611354 }
12621355 """
12631356 }
@@ -2345,14 +2438,22 @@ case class Cast(
23452438 """
23462439 }
23472440
2348- override def toString : String = s " cast( $child as ${dataType.simpleString}) "
2441+ override def prettyName : String = if (! isTryCast) {
2442+ " cast"
2443+ } else {
2444+ " try_cast"
2445+ }
2446+
2447+ override def toString : String = {
2448+ s " $prettyName( $child as ${dataType.simpleString}) "
2449+ }
23492450
23502451 override def sql : String = dataType match {
23512452 // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this
23522453 // type of casting can only be introduced by the analyzer, and can be omitted when converting
23532454 // back to SQL query string.
23542455 case _ : ArrayType | _ : MapType | _ : StructType => child.sql
2355- case _ => s " CAST ( ${child.sql} AS ${dataType.sql}) "
2456+ case _ => s " ${prettyName.toUpperCase( Locale . ROOT )} ( ${child.sql} AS ${dataType.sql}) "
23562457 }
23572458}
23582459
0 commit comments