Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,19 +214,6 @@ object HiveTypeCoercion {
}

Union(newLeft, newRight)

// Also widen types for BinaryOperator.
case q: LogicalPlan => q transformExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType =>
val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
b.makeCopy(Array(newLeft, newRight))
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
}
}
}

Expand Down Expand Up @@ -672,20 +659,44 @@ object HiveTypeCoercion {
}

/**
* Casts types according to the expected input types for Expressions that have the trait
* [[ExpectsInputTypes]].
* Casts types according to the expected input types for [[Expression]]s.
*/
object ImplicitTypeCasts extends Rule[LogicalPlan] {
/**
 * Rewrites every expression in `plan` so that its children match the expression's expected
 * input types where an implicit conversion applies. Expressions for which no conversion
 * applies are returned unchanged.
 *
 * @param plan the logical plan whose expressions are inspected
 * @return the plan with casts inserted around children where applicable
 */
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
  // Skip nodes whose children have not been resolved yet.
  case e if !e.childrenResolved => e

  // Binary operators whose two sides differ in type: widen both sides to the tightest common
  // type, but only when the operator itself accepts that type.
  case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
    findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
      if (b.inputType.acceptsType(commonType)) {
        // If the expression accepts the tightest common type, cast to that.
        val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
        val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
        b.makeCopy(Array(newLeft, newRight))
      } else {
        // Otherwise, don't do anything with the expression.
        b
      }
    }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.

  // Expressions that opt in to general implicit casting: try to cast each child to the
  // corresponding expected input type.
  case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
    val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
      // If we cannot do the implicit cast, just use the original input.
      implicitCast(in, expected).getOrElse(in)
    }
    e.withNewChildren(children)

  case e: ExpectsInputTypes if e.inputTypes.nonEmpty =>
    // Convert NullType into some specific target type for ExpectsInputTypes that don't do
    // general implicit casting.
    val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
      if (in.dataType == NullType && !expected.acceptsType(NullType)) {
        Cast(in, expected.defaultConcreteType)
      } else {
        // Any other child is left as-is, even when `expected` does not accept its type.
        // NOTE(review): such mismatches are expected to be reported later, during
        // CheckAnalysis, rather than here -- confirm.
        in
      }
    }
    e.withNewChildren(children)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.types.AbstractDataType

import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts

/**
* A trait that gets mixed in to define the expected input types of an expression.
*
* This trait is typically used by operator expressions (e.g. [[Add]], [[Subtract]]) to define
* expected input types without any implicit casting.
*
* Most function expressions (e.g. [[Substring]]) should extend [[ImplicitCastInputTypes]] instead.
*/
trait ExpectsInputTypes { self: Expression =>

Expand All @@ -40,7 +45,7 @@ trait ExpectsInputTypes { self: Expression =>
val mismatches = children.zip(inputTypes).zipWithIndex.collect {
case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " +
s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}."
s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}."
}

if (mismatches.isEmpty) {
Expand All @@ -50,3 +55,11 @@ trait ExpectsInputTypes { self: Expression =>
}
}
}


/**
 * Marker trait for expressions whose children the analyzer may implicitly cast to the declared
 * `inputTypes`, via the [[ImplicitTypeCasts]] rule.
 *
 * It adds no members on top of [[ExpectsInputTypes]]; mixing it in is the whole contract.
 */
trait ImplicitCastInputTypes extends ExpectsInputTypes { self: Expression => }
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,20 @@ import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._

////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines the basic expression abstract classes in Catalyst, including:
// Expression: the base expression abstract class
// LeafExpression
// UnaryExpression
// BinaryExpression
// BinaryOperator
//
// For details, see their classdocs.
////////////////////////////////////////////////////////////////////////////////////////////////////

/**
* An expression in Catalyst.
*
* If an expression wants to be exposed in the function registry (so users can call it with
* "name(arguments...)", the concrete implementation must be a case class whose constructor
* arguments are all Expressions types.
Expand Down Expand Up @@ -335,15 +347,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express


/**
 * A [[BinaryExpression]] that is an operator, with two properties:
 *
 * 1. The string representation is "(x symbol y)", rather than "funcName(x, y)".
 * 2. Two inputs are expected to be the same type. If the two inputs have different types,
 *    the analyzer will find the tightest common type and do the proper type casting.
 */
abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
  self: Product =>

  /**
   * Expected input type from both left/right child expressions, similar to the
   * [[ImplicitCastInputTypes]] trait.
   */
  def inputType: AbstractDataType

  /** The operator token placed between the two operands in [[toString]]. */
  def symbol: String

  override def toString: String = s"($left $symbol $right)"

  /** Both children share the single expected [[inputType]]. */
  override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)

  /**
   * Runs the [[ExpectsInputTypes]] check first; if that succeeds, additionally requires the
   * left and right children to have exactly the same data type.
   *
   * @return the failure from the inherited check if there is one, a "differing types" failure
   *         if the inherited check passes but the operand types differ, and success otherwise
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    super.checkInputDataTypes() match {
      case TypeCheckResult.TypeCheckSuccess if left.dataType != right.dataType =>
        TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
          s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
      // Either a success with matching operand types, or a failure from the inherited check:
      // pass it through unchanged instead of rebuilding an identical result.
      case result => result
    }
  }
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes {
inputTypes: Seq[DataType] = Nil) extends Expression with ImplicitCastInputTypes {

override def nullable: Boolean = true

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,19 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

abstract class UnaryArithmetic extends UnaryExpression {
self: Product =>

case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def dataType: DataType = child.dataType
}

case class UnaryMinus(child: Expression) extends UnaryArithmetic {
override def toString: String = s"-$child"

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "operator -")

private lazy val numeric = TypeUtils.getNumeric(dataType)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
Expand All @@ -45,9 +41,13 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
}

case class UnaryPositive(child: Expression) extends UnaryArithmetic {
case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def prettyName: String = "positive"

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def dataType: DataType = child.dataType

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
defineCodeGen(ctx, ev, c => c)

Expand All @@ -57,9 +57,11 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
/**
* A function that gets the absolute value of a numeric value.
*/
case class Abs(child: Expression) extends UnaryArithmetic {
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function abs")
case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def dataType: DataType = child.dataType

private lazy val numeric = TypeUtils.getNumeric(dataType)

Expand All @@ -71,18 +73,6 @@ abstract class BinaryArithmetic extends BinaryOperator {

override def dataType: DataType = left.dataType

override def checkInputDataTypes(): TypeCheckResult = {
if (left.dataType != right.dataType) {
TypeCheckResult.TypeCheckFailure(
s"differing types in ${this.getClass.getSimpleName} " +
s"(${left.dataType} and ${right.dataType}).")
} else {
checkTypesInternal(dataType)
}
}

protected def checkTypesInternal(t: DataType): TypeCheckResult

/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String =
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
Expand All @@ -104,62 +94,61 @@ private[sql] object BinaryArithmetic {
}

/**
 * Binary `+` operator. Both operands must be of a numeric type; the operation is carried out
 * by the `Numeric` instance obtained for the expression's data type.
 */
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {

  /** Both operands are expected to be numeric. */
  override def inputType: AbstractDataType = NumericType

  override def symbol: String = "+"

  /** Name of the method used for this operation on decimal values. */
  override def decimalMethod: String = "$plus"

  // Stays unresolved while the result type is a fixed-precision decimal.
  // NOTE(review): presumably so a separate decimal-precision rule rewrites it first -- confirm.
  override lazy val resolved: Boolean =
    childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

  private lazy val numeric = TypeUtils.getNumeric(dataType)

  protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
}

/**
 * Binary `-` operator. Both operands must be of a numeric type; the operation is carried out
 * by the `Numeric` instance obtained for the expression's data type.
 */
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {

  /** Both operands are expected to be numeric. */
  override def inputType: AbstractDataType = NumericType

  override def symbol: String = "-"

  /** Name of the method used for this operation on decimal values. */
  override def decimalMethod: String = "$minus"

  // Stays unresolved while the result type is a fixed-precision decimal.
  // NOTE(review): presumably so a separate decimal-precision rule rewrites it first -- confirm.
  override lazy val resolved: Boolean =
    childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

  private lazy val numeric = TypeUtils.getNumeric(dataType)

  protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
}

/**
 * Binary `*` operator. Both operands must be of a numeric type; the operation is carried out
 * by the `Numeric` instance obtained for the expression's data type.
 */
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {

  /** Both operands are expected to be numeric. */
  override def inputType: AbstractDataType = NumericType

  override def symbol: String = "*"

  /** Name of the method used for this operation on decimal values. */
  override def decimalMethod: String = "$times"

  // Stays unresolved while the result type is a fixed-precision decimal.
  // NOTE(review): presumably so a separate decimal-precision rule rewrites it first -- confirm.
  override lazy val resolved: Boolean =
    childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

  private lazy val numeric = TypeUtils.getNumeric(dataType)

  protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}

case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = NumericType

override def symbol: String = "/"
override def decimalMethod: String = "$div"

override def nullable: Boolean = true

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val div: (Any, Any) => Any = dataType match {
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
Expand Down Expand Up @@ -215,17 +204,16 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}

case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = NumericType

override def symbol: String = "%"
override def decimalMethod: String = "remainder"

override def nullable: Boolean = true

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val integral = dataType match {
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
Expand Down Expand Up @@ -281,10 +269,11 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}

case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def nullable: Boolean = left.nullable && right.nullable
// TODO: Remove MaxOf and MinOf, and replace their usages with Greatest and Least.

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(t, "function maxOf")
override def inputType: AbstractDataType = TypeCollection.Ordered

override def nullable: Boolean = left.nullable && right.nullable

private lazy val ordering = TypeUtils.getOrdering(dataType)

Expand Down Expand Up @@ -335,10 +324,11 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}

case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def nullable: Boolean = left.nullable && right.nullable
// TODO: Remove MaxOf and MinOf, and replace their usages with Greatest and Least.

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(t, "function minOf")
override def inputType: AbstractDataType = TypeCollection.Ordered

override def nullable: Boolean = left.nullable && right.nullable

private lazy val ordering = TypeUtils.getOrdering(dataType)

Expand Down
Loading