Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class SqlLexical extends StdLexical {
",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>"
)

protected override def processIdent(name: String) = {
override protected def processIdent(name: String) = {
val token = normalizeKeyword(name)
if (reserved contains token) Keyword(token) else Identifier(name)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w

private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)

protected override def nullSafeEval(input: Any): Any = cast(input)
override protected def nullSafeEval(input: Any): Any = cast(input)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
// TODO: Add support for more data types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,18 @@ abstract class BinaryOperator extends BinaryExpression {
def symbol: String

override def toString: String = s"($left $symbol $right)"

override def checkInputDataTypes(): TypeCheckResult = {
if (left.dataType != right.dataType) {
TypeCheckResult.TypeCheckFailure(
s"differing types in ${this.getClass.getSimpleName} " +
s"(${left.dataType} and ${right.dataType}).")
} else {
checkTypesInternal(dataType)
}
}

protected def checkTypesInternal(t: DataType): TypeCheckResult
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
}

protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
override protected def nullSafeEval(input: Any): Any = numeric.negate(input)
}

case class UnaryPositive(child: Expression) extends UnaryArithmetic {
Expand All @@ -51,7 +51,7 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
defineCodeGen(ctx, ev, c => c)

protected override def nullSafeEval(input: Any): Any = input
override protected def nullSafeEval(input: Any): Any = input
}

/**
Expand All @@ -63,26 +63,14 @@ case class Abs(child: Expression) extends UnaryArithmetic {

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
override protected def nullSafeEval(input: Any): Any = numeric.abs(input)
}

abstract class BinaryArithmetic extends BinaryOperator {
self: Product =>

override def dataType: DataType = left.dataType

override def checkInputDataTypes(): TypeCheckResult = {
if (left.dataType != right.dataType) {
TypeCheckResult.TypeCheckFailure(
s"differing types in ${this.getClass.getSimpleName} " +
s"(${left.dataType} and ${right.dataType}).")
} else {
checkTypesInternal(dataType)
}
}

protected def checkTypesInternal(t: DataType): TypeCheckResult

/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String =
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
Expand Down Expand Up @@ -110,12 +98,12 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
override protected def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
}

case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
Expand All @@ -125,12 +113,12 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
override protected def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
}

case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
Expand All @@ -140,12 +128,12 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
override protected def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}

case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
Expand All @@ -157,7 +145,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val div: (Any, Any) => Any = dataType match {
Expand Down Expand Up @@ -223,7 +211,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val integral = dataType match {
Expand Down Expand Up @@ -283,7 +271,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def nullable: Boolean = left.nullable && right.nullable

protected def checkTypesInternal(t: DataType) =
override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(t, "function maxOf")

private lazy val ordering = TypeUtils.getOrdering(dataType)
Expand Down Expand Up @@ -337,7 +325,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def nullable: Boolean = left.nullable && right.nullable

protected def checkTypesInternal(t: DataType) =
override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(t, "function minOf")

private lazy val ordering = TypeUtils.getOrdering(dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "&"

protected def checkTypesInternal(t: DataType) =
override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val and: (Any, Any) => Any = dataType match {
Expand All @@ -45,7 +45,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
}

protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2)
override protected def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2)
}

/**
Expand All @@ -56,7 +56,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "|"

protected def checkTypesInternal(t: DataType) =
override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val or: (Any, Any) => Any = dataType match {
Expand All @@ -70,7 +70,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
}

protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2)
override protected def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2)
}

/**
Expand All @@ -81,7 +81,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "^"

protected def checkTypesInternal(t: DataType) =
override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val xor: (Any, Any) => Any = dataType match {
Expand All @@ -95,7 +95,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
}

protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2)
override protected def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2)
}

/**
Expand All @@ -122,5 +122,5 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)")
}

protected override def nullSafeEval(input: Any): Any = not(input)
override protected def nullSafeEval(input: Any): Any = not(input)
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable

protected override def nullSafeEval(input: Any): Any =
override protected def nullSafeEval(input: Any): Any =
input.asInstanceOf[InternalRow](ordinal)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
Expand Down Expand Up @@ -160,7 +160,7 @@ case class GetArrayStructFields(
override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def nullable: Boolean = child.nullable || containsNull || field.nullable

protected override def nullSafeEval(input: Any): Any = {
override protected def nullSafeEval(input: Any): Any = {
input.asInstanceOf[Seq[InternalRow]].map { row =>
if (row == null) null else row(ordinal)
}
Expand Down Expand Up @@ -204,7 +204,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)

override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
override protected def nullSafeEval(value: Any, ordinal: Any): Any = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
val baseValue = value.asInstanceOf[Seq[_]]
Expand Down Expand Up @@ -248,7 +248,7 @@ case class GetMapValue(child: Expression, key: Expression)

override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
override protected def nullSafeEval(value: Any, ordinal: Any): Any = {
val baseValue = value.asInstanceOf[Map[Any, _]]
baseValue.get(ordinal).orNull
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
override def dataType: DataType = LongType
override def toString: String = s"UnscaledValue($child)"

protected override def nullSafeEval(input: Any): Any =
override protected def nullSafeEval(input: Any): Any =
input.asInstanceOf[Decimal].toUnscaledLong

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
Expand All @@ -48,7 +48,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
override def dataType: DataType = DecimalType(precision, scale)
override def toString: String = s"MakeDecimal($child,$precision,$scale)"

protected override def nullSafeEval(input: Any): Any =
override protected def nullSafeEval(input: Any): Any =
Decimal(input.asInstanceOf[Long], precision, scale)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
override def nullable: Boolean = true
override def toString: String = s"$name($child)"

protected override def nullSafeEval(input: Any): Any = {
override protected def nullSafeEval(input: Any): Any = {
val result = f(input.asInstanceOf[Double])
if (result.isNaN) null else result
}
Expand Down Expand Up @@ -97,7 +97,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)

override def dataType: DataType = DoubleType

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
val result = f(input1.asInstanceOf[Double], input2.asInstanceOf[Double])
if (result.isNaN) null else result
}
Expand Down Expand Up @@ -183,7 +183,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ExpectsInpu
// If the value not in the range of [0, 20], it still will be null, so set it to be true here.
override def nullable: Boolean = true

protected override def nullSafeEval(input: Any): Any = {
override protected def nullSafeEval(input: Any): Any = {
val value = input.asInstanceOf[jl.Integer]
if (value > 20 || value < 0) {
null
Expand Down Expand Up @@ -256,7 +256,7 @@ case class Bin(child: Expression)
override def inputTypes: Seq[DataType] = Seq(LongType)
override def dataType: DataType = StringType

protected override def nullSafeEval(input: Any): Any =
override protected def nullSafeEval(input: Any): Any =
UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long]))

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
Expand Down Expand Up @@ -293,7 +293,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes

override def dataType: DataType = StringType

protected override def nullSafeEval(num: Any): Any = child.dataType match {
override protected def nullSafeEval(num: Any): Any = child.dataType match {
case LongType => hex(num.asInstanceOf[Long])
case BinaryType => hex(num.asInstanceOf[Array[Byte]])
case StringType => hex(num.asInstanceOf[UTF8String].getBytes)
Expand Down Expand Up @@ -337,7 +337,7 @@ case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTyp
override def nullable: Boolean = true
override def dataType: DataType = BinaryType

protected override def nullSafeEval(num: Any): Any =
override protected def nullSafeEval(num: Any): Any =
unhex(num.asInstanceOf[UTF8String].getBytes)

private[this] def unhex(bytes: Array[Byte]): Array[Byte] = {
Expand Down Expand Up @@ -383,7 +383,7 @@ case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTyp
case class Atan2(left: Expression, right: Expression)
extends BinaryMathExpression(math.atan2, "ATAN2") {

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
// With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
val result = math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0)
if (result.isNaN) null else result
Expand Down Expand Up @@ -423,7 +423,7 @@ case class ShiftLeft(left: Expression, right: Expression)

override def dataType: DataType = left.dataType

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
input1 match {
case l: jl.Long => l << input2.asInstanceOf[jl.Integer]
case i: jl.Integer => i << input2.asInstanceOf[jl.Integer]
Expand All @@ -449,7 +449,7 @@ case class ShiftRight(left: Expression, right: Expression)

override def dataType: DataType = left.dataType

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
input1 match {
case l: jl.Long => l >> input2.asInstanceOf[jl.Integer]
case i: jl.Integer => i >> input2.asInstanceOf[jl.Integer]
Expand All @@ -475,7 +475,7 @@ case class ShiftRightUnsigned(left: Expression, right: Expression)

override def dataType: DataType = left.dataType

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
input1 match {
case l: jl.Long => l >>> input2.asInstanceOf[jl.Integer]
case i: jl.Integer => i >>> input2.asInstanceOf[jl.Integer]
Expand Down
Loading