Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -425,29 +425,54 @@ object Cast {
}
}

abstract class CastBase extends UnaryExpression
with TimeZoneAwareExpression
with NullIntolerant
with SupportQueryContext {
/**
* Cast the child expression to the target data type.
*
* When cast from/to timezone related types, we need timeZoneId, which will be resolved with
* session local timezone by an analyzer [[ResolveTimeZone]].
*/
@ExpressionDescription(
usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
examples = """
Examples:
> SELECT _FUNC_('10' as int);
10
""",
since = "1.0.0",
group = "conversion_funcs")
case class Cast(
child: Expression,
dataType: DataType,
timeZoneId: Option[String] = None,
ansiEnabled: Boolean = SQLConf.get.ansiEnabled) extends UnaryExpression
with TimeZoneAwareExpression with NullIntolerant with SupportQueryContext {

def child: Expression
def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)

def dataType: DataType
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

/**
* Returns true iff we can cast `from` type to `to` type.
*/
def canCast(from: DataType, to: DataType): Boolean
override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)

/**
* Returns the error message if casting from one type to another one is invalid.
*/
def typeCheckFailureMessage: String
final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)

override def toString: String = s"cast($child as ${dataType.simpleString})"
/**
 * Builds the message reported when this cast fails its input type check.
 *
 * With `ansiEnabled` set, the message comes from `Cast.typeCheckFailureMessage`, which is
 * handed a configuration key/value pair: the store assignment policy entry when the
 * `Cast.BY_TABLE_INSERTION` tag is present on this node, the ANSI flag entry otherwise.
 * Without `ansiEnabled`, a plain "cannot cast" message is produced.
 */
private def typeCheckFailureMessage: String = {
  if (!ansiEnabled) {
    s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
  } else {
    val taggedByTableInsertion = getTagValue(Cast.BY_TABLE_INSERTION).isDefined
    // Pick the configuration entry that accompanies the ANSI failure message.
    val fallbackConf =
      if (taggedByTableInsertion) {
        SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString
      } else {
        SQLConf.ANSI_ENABLED.key -> "false"
      }
    Cast.typeCheckFailureMessage(child.dataType, dataType, Some(fallbackConf))
  }
}

override def checkInputDataTypes(): TypeCheckResult = {
if (canCast(child.dataType, dataType)) {
if (ansiEnabled && Cast.canAnsiCast(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else if (!ansiEnabled && Cast.canCast(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(typeCheckFailureMessage)
Expand All @@ -456,8 +481,6 @@ abstract class CastBase extends UnaryExpression

override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType)

protected def ansiEnabled: Boolean

override def initQueryContext(): String = if (ansiEnabled) {
origin.context
} else {
Expand All @@ -470,7 +493,7 @@ abstract class CastBase extends UnaryExpression
childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)

override lazy val preCanonicalized: Expression = {
val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[CastBase]
val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[Cast]
if (timeZoneId.isDefined && !needsTimeZone) {
basic.withTimeZone(null)
} else {
Expand Down Expand Up @@ -2246,6 +2269,8 @@ abstract class CastBase extends UnaryExpression
"""
}

override def toString: String = s"cast($child as ${dataType.simpleString})"

override def sql: String = dataType match {
// HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this
// type of casting can only be introduced by the analyzer, and can be omitted when converting
Expand All @@ -2255,57 +2280,6 @@ abstract class CastBase extends UnaryExpression
}
}

/**
* Cast the child expression to the target data type.
*
* When cast from/to timezone related types, we need timeZoneId, which will be resolved with
* session local timezone by an analyzer [[ResolveTimeZone]].
*/
@ExpressionDescription(
usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
examples = """
Examples:
> SELECT _FUNC_('10' as int);
10
""",
since = "1.0.0",
group = "conversion_funcs")
case class Cast(
    child: Expression,
    dataType: DataType,
    timeZoneId: Option[String] = None,
    override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
  extends CastBase {

  // Three-argument constructor; the ANSI flag is read from the active SQLConf.
  def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
    this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)

  // Returns a copy carrying the given time zone id; Option(...) turns a null id into None.
  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
    copy(timeZoneId = Option(timeZoneId))

  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)

  // Delegates to the companion object's ANSI or non-ANSI rule set, depending on `ansiEnabled`.
  override def canCast(from: DataType, to: DataType): Boolean =
    if (ansiEnabled) Cast.canAnsiCast(from, to) else Cast.canCast(from, to)

  // Message used when the type check fails. In ANSI mode the companion object builds it and is
  // given a configuration key/value pair chosen by whether the BY_TABLE_INSERTION tag is set.
  override def typeCheckFailureMessage: String = {
    if (!ansiEnabled) {
      s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
    } else if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
      Cast.typeCheckFailureMessage(
        child.dataType,
        dataType,
        Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
    } else {
      Cast.typeCheckFailureMessage(
        child.dataType,
        dataType,
        Some(SQLConf.ANSI_ENABLED.key -> "false"))
    }
  }

  override protected def withNewChildInternal(newChild: Expression): Cast =
    copy(child = newChild)
}

/**
* Cast the child expression to the target data type, but will throw error if the cast might
* truncate, e.g. long -> int, timestamp -> date.
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

case class TryEval(child: Expression) extends UnaryExpression with NullIntolerant {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -53,6 +56,87 @@ case class TryEval(child: Expression) extends UnaryExpression with NullIntoleran
copy(child = newChild)
}

/**
* A special version of [[Cast]] with ANSI mode on. It performs the same operation (i.e. converts a
* value of one data type into another data type), but returns a NULL value instead of raising an
* error when the conversion cannot be performed.
*/
@ExpressionDescription(
usage = """
_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.
This expression is identical to CAST with configuration `spark.sql.ansi.enabled` as
true, except it returns NULL instead of raising an error. Note that the behavior of this
expression doesn't depend on configuration `spark.sql.ansi.enabled`.
""",
examples = """
Examples:
> SELECT _FUNC_('10' as int);
10
> SELECT _FUNC_(1234567890123L as int);
null
""",
since = "3.2.0",
group = "conversion_funcs")
case class TryCast(child: Expression, toType: DataType, timeZoneId: Option[String] = None)
  extends UnaryExpression with RuntimeReplaceable with TimeZoneAwareExpression {

  // Returns a copy carrying the given time zone id; Option(...) turns a null id into None.
  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
    copy(timeZoneId = Option(timeZoneId))

  override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)

  // When this cast involves TimeZone, it's only resolved if the timeZoneId is set;
  // otherwise behave like Expression.resolved.
  override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess &&
    (!Cast.needsTimeZone(child.dataType, toType) || timeZoneId.isDefined)

  // Equivalent expression: a Cast to `toType` with ansiEnabled = true, wrapped in TryEval.
  override lazy val replacement = {
    TryEval(Cast(child, toType, timeZoneId = timeZoneId, ansiEnabled = true))
  }

  /**
   * Returns true iff a value of type `from` may be try-cast to type `to`.
   *
   * If the target data type is a complex type which can't have Null values, we should guarantee
   * that the casting between the element types won't produce Null results. Arrays, maps and
   * structs are therefore checked recursively together with their nullability flags; every
   * other pair of types is delegated to `Cast.canAnsiCast`.
   */
  private def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
    case (ArrayType(fromElem, fromNullable), ArrayType(toElem, toNullable)) =>
      canCast(fromElem, toElem) &&
        resolvableNullability(fromNullable || forceNullable(fromElem, toElem), toNullable)

    case (MapType(fromKey, fromValue, fromNullable), MapType(toKey, toValue, toNullable)) =>
      canCast(fromKey, toKey) &&
        // The key cast must not be one that is forced nullable.
        !forceNullable(fromKey, toKey) &&
        canCast(fromValue, toValue) &&
        resolvableNullability(fromNullable || forceNullable(fromValue, toValue), toNullable)

    case (StructType(fromFields), StructType(toFields)) =>
      // Fields are compared by position, so both structs need the same number of fields.
      fromFields.length == toFields.length &&
        fromFields.zip(toFields).forall { case (fromField, toField) =>
          canCast(fromField.dataType, toField.dataType) &&
            resolvableNullability(
              fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
              toField.nullable)
        }

    case _ =>
      Cast.canAnsiCast(from, to)
  }

  /**
   * Succeeds when the child's type can be try-cast to `toType`; otherwise fails with the
   * message built by `Cast.typeCheckFailureMessage`.
   *
   * The check uses `toType` directly rather than the inherited `dataType`, so the type that
   * is validated is the same one named in the failure message and used by `resolved`.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (canCast(child.dataType, toType)) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(Cast.typeCheckFailureMessage(child.dataType, toType, None))
    }
  }

  // Both renderings name the declared target type `toType`.
  override def toString: String = s"try_cast($child as ${toType.simpleString})"

  override def sql: String = s"TRY_CAST(${child.sql} AS ${toType.sql})"

  override protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr1, expr2) - Returns the sum of `expr1`and `expr2` and the result is null on overflow. " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,8 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true
case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length =>
true
case _: CastBase => true
case _: Cast => true
case _: TryEval => true
case _: GetDateField | _: LastDay => true
case _: ExtractIntervalPart[_] => true
case _: ArraySetLike => true
Expand Down
Loading