From cb77e4f644441b137a9ab8ed2568ddce3bc0f053 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 27 May 2015 23:18:46 +0800 Subject: [PATCH 01/17] Improve error reporting for expression data type mismatch --- .../sql/catalyst/analysis/CheckAnalysis.scala | 10 +- .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 12 +- .../sql/catalyst/expressions/arithmetic.scala | 323 ++++++++---------- .../expressions/mathfuncs/binary.scala | 7 - .../expressions/mathfuncs/unary.scala | 1 - .../sql/catalyst/expressions/predicates.scala | 113 +++--- .../spark/sql/catalyst/util/TypeUtils.scala | 56 +++ .../org/apache/spark/sql/types/DataType.scala | 2 +- 9 files changed, 276 insertions(+), 250 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 193dc6b6546b..bded3b664d8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -62,15 +62,15 @@ trait CheckAnalysis { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + case e: Expression if !e.validInputTypes => + e.failAnalysis( + s"cannot resolve '${e.prettyString}' due to data type mismatch: " + + e.typeMismatchErrorMessage.get) + case c: Cast if !c.resolved => failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - case b: BinaryExpression if !b.resolved => - failAnalysis( - s"invalid expression ${b.prettyString} " + - s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}") - case WindowExpression(UnresolvedWindowFunction(name, _), _) => failAnalysis( s"Could not resolve window function '$name'. 
" + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index edcc918bfe92..ba39a0ef2e2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -407,7 +407,7 @@ trait HiveTypeCoercion { Union(newLeft, newRight) // fix decimal precision for expressions - case q => q.transformExpressions { + case q => q.transformExpressionsUp { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d19928784442..56246a2bdc6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -86,12 +86,16 @@ abstract class Expression extends TreeNode[Expression] { case (i1, i2) => i1 == i2 } } + + def typeMismatchErrorMessage: Option[String] = None + + def validInputTypes: Boolean = typeMismatchErrorMessage.isEmpty } abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { self: Product => - def symbol: String + def symbol: String = sys.error(s"BinaryExpressions must either override toString or symbol") override def foldable: Boolean = left.foldable && right.foldable @@ -106,6 +110,10 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => + + override def foldable: Boolean = child.foldable + + override def nullable: Boolean = child.nullable } // TODO Semantically we probably not need GroupExpression @@ -125,7 +133,9 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { * so that the proper type conversions can be performed in the analyzer. 
*/ trait ExpectsInputTypes { + self: Expression => def expectedChildTypes: Seq[DataType] + + override def validInputTypes: Boolean = children.map(_.dataType) == expectedChildTypes } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index f2299d5db6e9..e96ad4289d62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -17,72 +17,87 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -case class UnaryMinus(child: Expression) extends UnaryExpression { - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable - override def toString: String = s"-$child" - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } +abstract class UnaryArithmetic extends UnaryExpression { + self: Product => override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - numeric.negate(evalE) + evalInternal(evalE) } } + + protected def evalInternal(evalE: Any): Any = + sys.error(s"UnaryArithmetics must either override eval or evalInternal") } -case class Sqrt(child: Expression) extends UnaryExpression { +case class UnaryMinus(child: Expression) extends UnaryArithmetic { + override def dataType: DataType = child.dataType + override def toString: String = s"-$child" + + override def typeMismatchErrorMessage: Option[String] = { + TypeUtils.checkForNumericExpr(child.dataType, "todo") + } + + private lazy val numeric = TypeUtils.getNumeric(dataType) + protected override def evalInternal(evalE: Any) = numeric.negate(evalE) +} + +case class Sqrt(child: Expression) extends UnaryArithmetic { override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"SQRT($child)" - lazy val numeric = child.dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support non-negative numeric operations") + override def typeMismatchErrorMessage: Option[String] = { + TypeUtils.checkForNumericExpr(child.dataType, "todo") } - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val value = numeric.toDouble(evalE) - if (value < 0) null - else math.sqrt(value) - } + private lazy val numeric = TypeUtils.getNumeric(child.dataType) + + protected override def evalInternal(evalE: Any) = { + val value = numeric.toDouble(evalE) + if (value < 0) null + else math.sqrt(value) + } +} + +/** + * A function that gets the absolute value of the numeric value. 
+ */ +case class Abs(child: Expression) extends UnaryArithmetic { + override def dataType: DataType = child.dataType + override def toString: String = s"Abs($child)" + + override def typeMismatchErrorMessage: Option[String] = { + TypeUtils.checkForNumericExpr(child.dataType, "todo") } + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE: Any) = numeric.abs(evalE) } abstract class BinaryArithmetic extends BinaryExpression { self: Product => - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) + override def dataType: DataType = left.dataType - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + override def typeMismatchErrorMessage: Option[String] = { + if (left.dataType != right.dataType) { + Some(s"differing types in BinaryArithmetics, ${left.dataType}, ${right.dataType}") + } else { + errorMessageInternal(left.dataType) } - left.dataType } + protected def errorMessageInternal(t: DataType): Option[String] + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if(evalE1 == null) { @@ -97,88 +112,84 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - def evalInternal(evalE1: Any, evalE2: Any): Any = - sys.error(s"BinaryExpressions must either override eval or evalInternal") + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryArithmetics must either override eval or evalInternal") } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } - - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null + protected def errorMessageInternal(t: DataType) = { + TypeUtils.checkForNumericExpr(t, "todo").orElse { + if (DecimalType.isFixed(t)) { + Some("todo") } else { - numeric.plus(evalE1, evalE2) + None } } } + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = + numeric.plus(evalE1, evalE2) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } - - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null + protected def errorMessageInternal(t: DataType) = { + TypeUtils.checkForNumericExpr(t, "todo").orElse { + if (DecimalType.isFixed(t)) { + Some("todo") } else { - numeric.minus(evalE1, evalE2) + None } } } + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = + numeric.minus(evalE1, evalE2) } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" - lazy val numeric = dataType match { - case n: NumericType => 
n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } - - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null + protected def errorMessageInternal(t: DataType) = { + TypeUtils.checkForNumericExpr(t, "todo").orElse { + if (DecimalType.isFixed(t)) { + Some("todo") } else { - numeric.times(evalE1, evalE2) + None } } } + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = + numeric.times(evalE1, evalE2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" - override def nullable: Boolean = true - lazy val div: (Any, Any) => Any = dataType match { + protected def errorMessageInternal(t: DataType) = { + TypeUtils.checkForNumericExpr(t, "todo").orElse { + if (DecimalType.isFixed(t)) { + Some("todo") + } else { + None + } + } + } + + private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot - case other => sys.error(s"Type $other does not support numeric operations") } override def eval(input: Row): Any = { @@ -198,13 +209,21 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" - override def nullable: Boolean = true - lazy val integral = dataType match { + protected def errorMessageInternal(t: DataType) = { + TypeUtils.checkForNumericExpr(t, "todo").orElse { + if (DecimalType.isFixed(t)) { + Some("todo") + } else { + None + } + } + } + + private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] - case other => sys.error(s"Type $other does not support numeric operations") } override def eval(input: Row): Any = { @@ -228,7 +247,11 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "&" - lazy val and: (Any, Any) => Any = dataType match { + protected def errorMessageInternal(t: DataType) = { + TypeUtils.checkForBitwiseExpr(t, "todo") + } + + private lazy val and: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -237,10 +260,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise & operation on $other") } - override def evalInternal(evalE1: Any, evalE2: Any): Any = and(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = + and(evalE1, evalE2) } /** @@ -249,7 +272,11 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "|" - lazy val or: (Any, Any) => Any = dataType match { + protected def 
errorMessageInternal(t: DataType) = { + TypeUtils.checkForBitwiseExpr(t, "todo") + } + + private lazy val or: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -258,10 +285,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise | operation on $other") } - override def evalInternal(evalE1: Any, evalE2: Any): Any = or(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = + or(evalE1, evalE2) } /** @@ -270,7 +297,11 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "^" - lazy val xor: (Any, Any) => Any = dataType match { + protected def errorMessageInternal(t: DataType) = { + TypeUtils.checkForBitwiseExpr(t, "todo") + } + + private lazy val xor: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -279,23 +310,24 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise ^ operation on $other") } - override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any): Any = + xor(evalE1, evalE2) } /** * A function that calculates bitwise not(~) of a number. 
*/ -case class BitwiseNot(child: Expression) extends UnaryExpression { - +case class BitwiseNot(child: Expression) extends UnaryArithmetic { override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"~$child" - lazy val not: (Any) => Any = dataType match { + override def typeMismatchErrorMessage: Option[String] = { + TypeUtils.checkForBitwiseExpr(child.dataType, "todo") + } + + private lazy val not: (Any) => Any = dataType match { case ByteType => ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] case ShortType => @@ -304,42 +336,23 @@ case class BitwiseNot(child: Expression) extends UnaryExpression { ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] case LongType => ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] - case other => sys.error(s"Unsupported bitwise ~ operation on $other") } - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - not(evalE) - } - } + protected override def evalInternal(evalE: Any) = not(evalE) } -case class MaxOf(left: Expression, right: Expression) extends Expression { - - override def foldable: Boolean = left.foldable && right.foldable - +case class MaxOf(left: Expression, right: Expression) extends BinaryExpression { override def nullable: Boolean = left.nullable && right.nullable + override def dataType: DataType = left.dataType - override def children: Seq[Expression] = left :: right :: Nil + private lazy val ordering = TypeUtils.getOrdering(dataType) - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + override def typeMismatchErrorMessage: Option[String] = { + if (left.dataType != right.dataType) { + Some(s"differing types in BinaryArithmetics, ${left.dataType}, ${right.dataType}") + } else { + TypeUtils.checkForOrderingExpr(dataType, "todo") } - left.dataType - } - - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") } override def eval(input: Row): Any = { @@ -361,29 +374,18 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MaxOf($left, $right)" } -case class MinOf(left: Expression, right: Expression) extends Expression { - - override def foldable: Boolean = left.foldable && right.foldable - +case class MinOf(left: Expression, right: Expression) extends BinaryExpression { override def nullable: Boolean = left.nullable && right.nullable + override def dataType: DataType = left.dataType - override def children: Seq[Expression] = left :: right :: Nil + private lazy val ordering = TypeUtils.getOrdering(dataType) - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. 
Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + override def typeMismatchErrorMessage: Option[String] = { + if (left.dataType != right.dataType) { + Some(s"differing types in BinaryArithmetics, ${left.dataType}, ${right.dataType}") + } else { + TypeUtils.checkForOrderingExpr(dataType, "todo") } - left.dataType - } - - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") } override def eval(input: Row): Any = { @@ -404,28 +406,3 @@ case class MinOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MinOf($left, $right)" } - -/** - * A function that get the absolute value of the numeric value. - */ -case class Abs(child: Expression) extends UnaryExpression { - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable - override def toString: String = s"Abs($child)" - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } - - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - numeric.abs(evalE) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index 01f62ba0442e..2b2a994f843c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -29,17 +29,10 @@ import org.apache.spark.sql.types._ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => - override def symbol: String = null override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) - override def nullable: Boolean = left.nullable || right.nullable override def toString: String = s"$name($left, $right)" - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - override def dataType: DataType = DoubleType override def eval(input: Row): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index 41b422346a02..ff235f0ab58e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -31,7 +31,6 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"$name($child)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4f422d69c438..03af95c508fe 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType} +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType} object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -70,8 +69,6 @@ trait PredicateHelper { case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) @@ -171,6 +168,16 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => + + override def typeMismatchErrorMessage: Option[String] = { + if (left.dataType != right.dataType) { + Some(s"differing types in BinaryComparisons, ${left.dataType}, ${right.dataType}") + } else { + errorMessageInternal(left.dataType) + } + } + + protected def errorMessageInternal(t: DataType): Option[String] = None } case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { @@ -210,17 +217,12 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case class LessThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + override protected def errorMessageInternal(t: DataType) = { + TypeUtils.checkForOrderingExpr(t, "todo") } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { @@ -239,17 +241,12 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<=" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + override protected def errorMessageInternal(t: DataType) = { + TypeUtils.checkForOrderingExpr(t, "todo") } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { @@ -268,17 +265,12 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo case class GreaterThan(left: Expression, right: Expression) extends 
BinaryComparison { override def symbol: String = ">" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + override protected def errorMessageInternal(t: DataType) = { + TypeUtils.checkForOrderingExpr(t, "todo") } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if(evalE1 == null) { @@ -297,17 +289,12 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">=" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + override protected def errorMessageInternal(t: DataType) = { + TypeUtils.checkForOrderingExpr(t, "todo") } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { @@ -329,16 +316,16 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable - override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException( - this, - s"Can not resolve due to differing types ${trueValue.dataType}, ${falseValue.dataType}") + override def typeMismatchErrorMessage: Option[String] = { + if (trueValue.dataType != falseValue.dataType) { + Some(s"differing types in If, ${trueValue.dataType}, ${falseValue.dataType}") + } else { + None } - trueValue.dataType } + override def dataType: DataType = trueValue.dataType + override def eval(input: Row): Any = { if (true == predicate.eval(input)) { trueValue.eval(input) @@ -368,12 +355,7 @@ trait CaseWhenLike extends Expression { def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") - } - valueTypes.head - } + override def dataType: DataType = valueTypes.head override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. 
@@ -395,10 +377,15 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { override def children: Seq[Expression] = branches - override lazy val resolved: Boolean = - childrenResolved && - whenList.forall(_.dataType == BooleanType) && - valueTypesEqual + override def typeMismatchErrorMessage: Option[String] = { + if (!whenList.forall(_.dataType == BooleanType)) { + Some("WHEN expressions should all be of boolean type") + } else if (!valueTypesEqual) { + Some("THEN and ELSE expressions should all be the same type") + } else { + None + } + } /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { @@ -441,9 +428,13 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW override def children: Seq[Expression] = key +: branches - override lazy val resolved: Boolean = - childrenResolved && valueTypesEqual && - (key +: whenList).map(_.dataType).distinct.size == 1 + override def typeMismatchErrorMessage: Option[String] = { + if (!valueTypesEqual) { + Some("THEN and ELSE expressions should all be the same type") + } else { + None + } + } /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala new file mode 100644 index 000000000000..511d876e87e2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.types.{AtomicType, IntegralType, NumericType, DataType} + +/** + * Helper functions to check valid data types + */ +object TypeUtils { + + def checkForNumericExpr(t: DataType, errorMsg: => String): Option[String] = { + if (t.isInstanceOf[NumericType]) { + None + } else { + Some(errorMsg) + } + } + + def checkForBitwiseExpr(t: DataType, errorMsg: => String): Option[String] = { + if (t.isInstanceOf[IntegralType]) { + None + } else { + Some(errorMsg) + } + } + + def checkForOrderingExpr(t: DataType, errorMsg: => String): Option[String] = { + if (t.isInstanceOf[AtomicType]) { + None + } else { + Some(errorMsg) + } + } + + def getNumeric(t: DataType): Numeric[Any] = + t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] + + def getOrdering(t: DataType): Ordering[Any] = + t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 1ba3a2686639..74677ddfcad6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -107,7 +107,7 @@ protected[sql] abstract class AtomicType extends DataType { abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a - // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no // longer an no argument constructor and thus the JVM cannot serialize the object anymore. 
private[sql] val numeric: Numeric[InternalType] From 7ae76b93e11f63cbf35919565ce0b0edbac06ed1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 28 May 2015 14:48:54 +0800 Subject: [PATCH 02/17] address comments --- .../sql/catalyst/analysis/CheckAnalysis.scala | 6 +- .../sql/catalyst/expressions/Expression.scala | 17 +- .../sql/catalyst/expressions/arithmetic.scala | 193 ++++++++++-------- .../expressions/mathfuncs/binary.scala | 10 +- .../sql/catalyst/expressions/predicates.scala | 134 ++++++------ .../sql/catalyst/optimizer/Optimizer.scala | 4 + .../spark/sql/catalyst/util/TypeUtils.scala | 30 +-- 7 files changed, 187 insertions(+), 207 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bded3b664d8c..7fd284bb315c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -62,10 +62,10 @@ trait CheckAnalysis { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") - case e: Expression if !e.validInputTypes => + case e: Expression if e.checkInputDataTypes.isDefined => e.failAnalysis( s"cannot resolve '${e.prettyString}' due to data type mismatch: " + - e.typeMismatchErrorMessage.get) + e.checkInputDataTypes.get) case c: Cast if !c.resolved => failAnalysis( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 56246a2bdc6b..ac71d3b02fff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -87,9 +87,10 @@ abstract class Expression extends TreeNode[Expression] { } } - def typeMismatchErrorMessage: Option[String] = None - - def validInputTypes: Boolean = typeMismatchErrorMessage.isEmpty + /** + * Checks the input data types and returns an error message if they are invalid, or `None` if valid. + */ + def checkInputDataTypes: Option[String] = None } abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { @@ -110,10 +111,6 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => - - override def foldable: Boolean = child.foldable - - override def nullable: Boolean = child.nullable } // TODO Semantically we probably not need GroupExpression @@ -137,5 +134,9 @@ trait ExpectsInputTypes { def expectedChildTypes: Seq[DataType] - override def validInputTypes: Boolean = children.map(_.dataType) == expectedChildTypes + override def checkInputDataTypes: Option[String] = { + // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`, + // so type mismatch errors won't be reported here, but for the underlying `Cast`s. 
+ None + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index e96ad4289d62..9c2fe17de499 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -23,6 +23,18 @@ import org.apache.spark.sql.types._ abstract class UnaryArithmetic extends UnaryExpression { self: Product => + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def dataType: DataType = child.dataType + + override def checkInputDataTypes: Option[String] = { + if (TypeUtils.validForNumericExpr(child.dataType)) { + None + } else { + Some("todo") + } + } + override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { @@ -37,13 +49,8 @@ abstract class UnaryArithmetic extends UnaryExpression { } case class UnaryMinus(child: Expression) extends UnaryArithmetic { - override def dataType: DataType = child.dataType override def toString: String = s"-$child" - override def typeMismatchErrorMessage: Option[String] = { - TypeUtils.checkForNumericExpr(child.dataType, "todo") - } - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def evalInternal(evalE: Any) = numeric.negate(evalE) @@ -54,10 +61,6 @@ case class Sqrt(child: Expression) extends UnaryArithmetic { override def nullable: Boolean = true override def toString: String = s"SQRT($child)" - override def typeMismatchErrorMessage: Option[String] = { - TypeUtils.checkForNumericExpr(child.dataType, "todo") - } - private lazy val numeric = TypeUtils.getNumeric(child.dataType) protected override def evalInternal(evalE: Any) = { @@ -71,13 +74,8 @@ * A function that gets the absolute value of the numeric value. 
*/ case class Abs(child: Expression) extends UnaryArithmetic { - override def dataType: DataType = child.dataType override def toString: String = s"Abs($child)" - override def typeMismatchErrorMessage: Option[String] = { - TypeUtils.checkForNumericExpr(child.dataType, "todo") - } - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def evalInternal(evalE: Any) = numeric.abs(evalE) @@ -88,15 +86,15 @@ abstract class BinaryArithmetic extends BinaryExpression { override def dataType: DataType = left.dataType - override def typeMismatchErrorMessage: Option[String] = { + override def checkInputDataTypes: Option[String] = { if (left.dataType != right.dataType) { Some(s"differing types in BinaryArithmetics, ${left.dataType}, ${right.dataType}") } else { - errorMessageInternal(left.dataType) + checkTypesInternal(dataType) } } - protected def errorMessageInternal(t: DataType): Option[String] + protected def checkTypesInternal(t: DataType): Option[String] override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -119,71 +117,76 @@ abstract class BinaryArithmetic extends BinaryExpression { case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" - protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForNumericExpr(t, "todo").orElse { - if (DecimalType.isFixed(t)) { - Some("todo") - } else { - None - } + // We will always cast fixed decimal to unlimited decimal + // for `Add` in `HiveTypeCoercion` + override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForNumericExpr(t)) { + None + } else { + Some("todo") } } private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def evalInternal(evalE1: Any, evalE2: Any) = - numeric.plus(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.plus(evalE1, evalE2) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" - protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForNumericExpr(t, "todo").orElse { - if (DecimalType.isFixed(t)) { - Some("todo") - } else { - None - } + // We will always cast fixed decimal to unlimited decimal + // for `Subtract` in `HiveTypeCoercion` + override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForNumericExpr(t)) { + None + } else { + Some("todo") } } private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def evalInternal(evalE1: Any, evalE2: Any) = - numeric.minus(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.minus(evalE1, evalE2) } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" - protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForNumericExpr(t, "todo").orElse { - if (DecimalType.isFixed(t)) { - Some("todo") - } else { - None - } + // We will always cast fixed decimal to unlimited decimal + // for `Multiply` in `HiveTypeCoercion` + override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForNumericExpr(t)) { + None + } else { + Some("todo") } } private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def 
evalInternal(evalE1: Any, evalE2: Any) = - numeric.times(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.times(evalE1, evalE2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" override def nullable: Boolean = true - protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForNumericExpr(t, "todo").orElse { - if (DecimalType.isFixed(t)) { - Some("todo") - } else { - None - } + // We will always cast fixed decimal to unlimited decimal + // for `Divide` in `HiveTypeCoercion` + override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForNumericExpr(t)) { + None + } else { + Some("todo") } } @@ -211,13 +214,15 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def symbol: String = "%" override def nullable: Boolean = true - protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForNumericExpr(t, "todo").orElse { - if (DecimalType.isFixed(t)) { - Some("todo") - } else { - None - } + // We will always cast fixed decimal to unlimited decimal + // for `Remainder` in `HiveTypeCoercion` + override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForNumericExpr(t)) { + None + } else { + Some("todo") } } @@ -247,8 +252,12 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "&" - protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForBitwiseExpr(t, "todo") + protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForBitwiseExpr(t)) { + None + } else { + Some("todo") + } } private lazy val and: (Any, Any) => Any = dataType match { @@ -262,8 +271,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] } - protected override def evalInternal(evalE1: Any, evalE2: Any) = - and(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2) } /** @@ -272,8 +280,12 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "|" - protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForBitwiseExpr(t, "todo") + protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForBitwiseExpr(t)) { + None + } else { + Some("todo") + } } private lazy val or: (Any, Any) => Any = dataType match { @@ -287,8 +299,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] } - protected override def evalInternal(evalE1: Any, evalE2: Any) = - or(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2) } /** @@ -297,8 +308,12 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "^" - protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForBitwiseExpr(t, "todo") + 
protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForBitwiseExpr(t)) { + None + } else { + Some("todo") + } } private lazy val xor: (Any, Any) => Any = dataType match { @@ -312,19 +327,21 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] } - protected override def evalInternal(evalE1: Any, evalE2: Any): Any = - xor(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) } /** * A function that calculates bitwise not(~) of a number. */ case class BitwiseNot(child: Expression) extends UnaryArithmetic { - override def dataType: DataType = child.dataType override def toString: String = s"~$child" - override def typeMismatchErrorMessage: Option[String] = { - TypeUtils.checkForBitwiseExpr(child.dataType, "todo") + override def checkInputDataTypes: Option[String] = { + if (TypeUtils.validForBitwiseExpr(dataType)) { + None + } else { + Some("todo") + } } private lazy val not: (Any) => Any = dataType match { @@ -341,20 +358,19 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = not(evalE) } -case class MaxOf(left: Expression, right: Expression) extends BinaryExpression { +case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def dataType: DataType = left.dataType - - private lazy val ordering = TypeUtils.getOrdering(dataType) - override def typeMismatchErrorMessage: Option[String] = { - if (left.dataType != right.dataType) { - Some(s"differing types in BinaryArithmetics, ${left.dataType}, ${right.dataType}") + protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForOrderingExpr(t)) { + None } else { - TypeUtils.checkForOrderingExpr(dataType, "todo") + Some("todo") } } + private lazy val ordering = TypeUtils.getOrdering(dataType) + override def eval(input: Row): Any = { val evalE1 = left.eval(input) val evalE2 = right.eval(input) @@ -374,20 +390,19 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryExpression { override def toString: String = s"MaxOf($left, $right)" } -case class MinOf(left: Expression, right: Expression) extends BinaryExpression { +case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def dataType: DataType = left.dataType - - private lazy val ordering = TypeUtils.getOrdering(dataType) - override def typeMismatchErrorMessage: Option[String] = { - if (left.dataType != right.dataType) { - Some(s"differing types in BinaryArithmetics, ${left.dataType}, ${right.dataType}") + protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForOrderingExpr(t)) { + None } else { - TypeUtils.checkForOrderingExpr(dataType, "todo") + Some("todo") } } + private lazy val ordering = TypeUtils.getOrdering(dataType) + override def eval(input: Row): Any = { val evalE1 = left.eval(input) val evalE2 = right.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index 2b2a994f843c..db853a2b97fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -51,9 +51,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } -case class Atan2( - left: Expression, - right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { +case class Atan2(left: Expression, right: Expression) + extends BinaryMathExpression(math.atan2, "ATAN2") { override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -73,8 +72,7 @@ case class Atan2( } } -case class Hypot( - left: Expression, - right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") +case class Hypot(left: Expression, right: Expression) + extends BinaryMathExpression(math.hypot, "HYPOT") case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 03af95c508fe..161e5d909ddf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType} +import org.apache.spark.sql.types.{DecimalType, BinaryType, BooleanType, DataType} object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -69,6 +69,8 @@ trait PredicateHelper { case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) @@ -169,31 +171,41 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => - override def typeMismatchErrorMessage: Option[String] = { + override def checkInputDataTypes: Option[String] = { if (left.dataType != right.dataType) { Some(s"differing types in BinaryComparisons, ${left.dataType}, ${right.dataType}") } else { - errorMessageInternal(left.dataType) + checkTypesInternal(left.dataType) } } - protected def errorMessageInternal(t: DataType): Option[String] = None + protected def checkTypesInternal(t: DataType): Option[String] = None override def eval(input: Row): Any = { - val l = left.eval(input) - if (l == null) { + val evalE1 = left.eval(input) + if(evalE1 == null) { null } else { - val r = right.eval(input) - if (r == null) null - else if (left.dataType != BinaryType) l == r - else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + evalInternal(evalE1, evalE2) + } } } + + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryComparisons must either override eval or evalInternal") +} + +case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "=" + + protected 
override def evalInternal(l: Any, r: Any) = { + if (left.dataType != BinaryType) l == r + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + } } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { @@ -217,97 +229,65 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case class LessThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<" - override protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForOrderingExpr(t, "todo") + override protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForOrderingExpr(t)) { + None + } else { + Some("todo") + } } private lazy val ordering = TypeUtils.getOrdering(left.dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.lt(evalE1, evalE2) - } - } - } + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2) } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<=" - override protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForOrderingExpr(t, "todo") + override protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForOrderingExpr(t)) { + None + } else { + Some("todo") + } } private lazy val ordering = TypeUtils.getOrdering(left.dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.lteq(evalE1, evalE2) - } - } - } + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2) } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">" - override protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForOrderingExpr(t, "todo") + override protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForOrderingExpr(t)) { + None + } else { + Some("todo") + } } private lazy val ordering = TypeUtils.getOrdering(left.dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gt(evalE1, evalE2) - } - } - } + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2) } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">=" - override protected def errorMessageInternal(t: DataType) = { - TypeUtils.checkForOrderingExpr(t, "todo") + override protected def checkTypesInternal(t: DataType) = { + if (TypeUtils.validForOrderingExpr(t)) { + None + } else { + Some("todo") + } } private lazy val ordering = TypeUtils.getOrdering(left.dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gteq(evalE1, evalE2) - } - } - } + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2) } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) @@ -316,7 +296,7 @@ 
case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable - override def typeMismatchErrorMessage: Option[String] = { + override def checkInputDataTypes: Option[String] = { if (trueValue.dataType != falseValue.dataType) { Some(s"differing types in If, ${trueValue.dataType}, ${falseValue.dataType}") } else { @@ -377,7 +357,7 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { override def children: Seq[Expression] = branches - override def typeMismatchErrorMessage: Option[String] = { + override def checkInputDataTypes: Option[String] = { if (!whenList.forall(_.dataType == BooleanType)) { Some(s"WHEN expressions should all be boolean type") } else if (!valueTypesEqual) { @@ -428,7 +408,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW override def children: Seq[Expression] = key +: branches - override def typeMismatchErrorMessage: Option[String] = { + override def checkInputDataTypes: Option[String] = { if (!valueTypesEqual) { Some("THEN and ELSE expressions should all be same type") } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c2818d957cc7..3d62912cbdee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -264,6 +264,10 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) + // MaxOf and MinOf can't do null propagation + case e: MaxOf => e + case e: MinOf => e + // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 511d876e87e2..d8c1ee4864d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -17,36 +17,18 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.types.{AtomicType, IntegralType, NumericType, DataType} +import org.apache.spark.sql.types._ /** * Helper function to check valid data types */ object TypeUtils { - def checkForNumericExpr(t: DataType, errorMsg: => String): Option[String] = { - if (t.isInstanceOf[NumericType]) { - None - } else { - Some(errorMsg) - } - } - - def checkForBitwiseExpr(t: DataType, errorMsg: => String): Option[String] = { - if (t.isInstanceOf[IntegralType]) { - None - } else { - Some(errorMsg) - } - } - - def checkForOrderingExpr(t: DataType, errorMsg: => String): Option[String] = { - if (t.isInstanceOf[AtomicType]) { - None - } else { - Some(errorMsg) - } - } + def validForNumericExpr(t: DataType): Boolean = t.isInstanceOf[NumericType] || t == NullType + + def validForBitwiseExpr(t: DataType): Boolean = t.isInstanceOf[IntegralType] || t == NullType + + def validForOrderingExpr(t: DataType): Boolean = t.isInstanceOf[AtomicType] || t == NullType def 
getNumeric(t: DataType): Numeric[Any] =
    t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]

From 64917213bfc0251fed3ccac36e7ec7b2689d08b9 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 28 May 2015 16:10:37 +0800
Subject: [PATCH 03/17] use value class TypeCheckResult

---
 .../sql/catalyst/analysis/CheckAnalysis.scala |  4 +-
 .../catalyst/analysis/TypeCheckResult.scala   | 30 ++++++++++
 .../sql/catalyst/expressions/Expression.scala |  8 +--
 .../sql/catalyst/expressions/arithmetic.scala | 60 ++++++++++---------
 .../sql/catalyst/expressions/predicates.scala | 45 +++++++-------
 5 files changed, 91 insertions(+), 56 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 7fd284bb315c..5b689f22bedb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -62,10 +62,10 @@ trait CheckAnalysis {
             val from = operator.inputSet.map(_.name).mkString(", ")
             a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
 
-          case e: Expression if e.checkInputDataTypes.isDefined =>
+          case e: Expression if e.checkInputDataTypes.hasError =>
            e.failAnalysis(
               s"cannot resolve '${e.prettyString}' due to data type mismatch: " +
-                e.checkInputDataTypes.get)
+                e.checkInputDataTypes.errorMessage)
 
           case c: Cast if !c.resolved =>
             failAnalysis(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala
new file mode 100644
index 000000000000..1020fc400187
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+/**
+ * The result of checking the input data types of an expression.
+ */
+class TypeCheckResult(val errorMessage: String) extends AnyVal {
+  def hasError: Boolean = errorMessage != null
+}
+
+object TypeCheckResult {
+  val success: TypeCheckResult = new TypeCheckResult(null)
+  def fail(msg: String): TypeCheckResult = new TypeCheckResult(msg)
+}
\ No newline at end of file
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ac71d3b02fff..65c5212d407d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types._
@@ -90,7 +90,7 @@ abstract class Expression extends TreeNode[Expression] {
   /**
    * todo
    */
-  def checkInputDataTypes: Option[String] = None
+  def checkInputDataTypes: TypeCheckResult = TypeCheckResult.success
 }
 
 abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
@@ -134,9 +134,9 @@ trait ExpectsInputTypes {
 
   def expectedChildTypes: Seq[DataType]
 
-  override def checkInputDataTypes: Option[String] = {
+  override def checkInputDataTypes: TypeCheckResult = {
     // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
     // so type mismatch error won't be reported here, but for underlying `Cast`s.
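As an aside on the design: encoding success as a null error message is what lets `TypeCheckResult` be an `AnyVal` value class, giving an Option-like API without allocating a wrapper object on the common success path. A minimal, self-contained sketch of the same pattern; the names `Result` and `checkNumeric` are illustrative only, not part of the patch:

object TypeCheckResultSketch {
  // Mirrors the value class above: a null message encodes success, so the
  // success singleton carries no real data and the check stays allocation-light.
  class Result(val errorMessage: String) extends AnyVal {
    def hasError: Boolean = errorMessage != null
  }
  object Result {
    val success: Result = new Result(null)
    def fail(msg: String): Result = new Result(msg)
  }

  // A hypothetical checker in the style of checkInputDataTypes.
  def checkNumeric(dataType: String): Result =
    if (dataType == "IntegerType" || dataType == "DoubleType") Result.success
    else Result.fail(s"expected a numeric type, got $dataType")

  def main(args: Array[String]): Unit = {
    val r = checkNumeric("BooleanType")
    if (r.hasError) println(s"type check failed: ${r.errorMessage}")
  }
}

Since callers only ever go through the two factory entry points, replacing the null-based encoding with another representation later would be a local change.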
- None + TypeCheckResult.success } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9c2fe17de499..4e02c7a28e63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -27,11 +28,11 @@ abstract class UnaryArithmetic extends UnaryExpression { override def nullable: Boolean = child.nullable override def dataType: DataType = child.dataType - override def checkInputDataTypes: Option[String] = { + override def checkInputDataTypes: TypeCheckResult = { if (TypeUtils.validForNumericExpr(child.dataType)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -86,15 +87,16 @@ abstract class BinaryArithmetic extends BinaryExpression { override def dataType: DataType = left.dataType - override def checkInputDataTypes: Option[String] = { + override def checkInputDataTypes: TypeCheckResult = { if (left.dataType != right.dataType) { - Some(s"differing types in BinaryArithmetics, ${left.dataType}, ${right.dataType}") + TypeCheckResult.fail( + s"differing types in BinaryArithmetics -- ${left.dataType}, ${right.dataType}") } else { checkTypesInternal(dataType) } } - protected def checkTypesInternal(t: DataType): Option[String] + protected def checkTypesInternal(t: DataType): TypeCheckResult override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -123,9 +125,9 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForNumericExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -143,9 +145,9 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForNumericExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -163,9 +165,9 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForNumericExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -184,9 +186,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForNumericExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -220,9 +222,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForNumericExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -254,9 +256,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForBitwiseExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -282,9 +284,9 @@ case class BitwiseOr(left: 
Expression, right: Expression) extends BinaryArithmet protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForBitwiseExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -310,9 +312,9 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForBitwiseExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -336,11 +338,11 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme case class BitwiseNot(child: Expression) extends UnaryArithmetic { override def toString: String = s"~$child" - override def checkInputDataTypes: Option[String] = { + override def checkInputDataTypes: TypeCheckResult = { if (TypeUtils.validForBitwiseExpr(dataType)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -363,9 +365,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForOrderingExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -395,9 +397,9 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForOrderingExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 161e5d909ddf..435946732ec4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.{DecimalType, BinaryType, BooleanType, DataType} @@ -171,15 +172,16 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => - override def checkInputDataTypes: Option[String] = { + override def checkInputDataTypes: TypeCheckResult = { if (left.dataType != right.dataType) { - Some(s"differing types in BinaryComparisons, ${left.dataType}, ${right.dataType}") + TypeCheckResult.fail( + s"differing types in BinaryComparisons -- ${left.dataType}, ${right.dataType}") } else { checkTypesInternal(left.dataType) } } - protected def checkTypesInternal(t: DataType): Option[String] = None + protected def checkTypesInternal(t: DataType): TypeCheckResult = TypeCheckResult.success override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -231,9 +233,9 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso override protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForOrderingExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -247,9 +249,9 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo override protected def checkTypesInternal(t: DataType) = { if 
(TypeUtils.validForOrderingExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -263,9 +265,9 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar override protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForOrderingExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -279,9 +281,9 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar override protected def checkTypesInternal(t: DataType) = { if (TypeUtils.validForOrderingExpr(t)) { - None + TypeCheckResult.success } else { - Some("todo") + TypeCheckResult.fail("todo") } } @@ -296,11 +298,12 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable - override def checkInputDataTypes: Option[String] = { + override def checkInputDataTypes: TypeCheckResult = { if (trueValue.dataType != falseValue.dataType) { - Some(s"differing types in If, ${trueValue.dataType}, ${falseValue.dataType}") + TypeCheckResult.fail( + s"differing types in If, ${trueValue.dataType}, ${falseValue.dataType}") } else { - None + TypeCheckResult.success } } @@ -357,13 +360,13 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { override def children: Seq[Expression] = branches - override def checkInputDataTypes: Option[String] = { + override def checkInputDataTypes: TypeCheckResult = { if (!whenList.forall(_.dataType == BooleanType)) { - Some(s"WHEN expressions should all be boolean type") + TypeCheckResult.fail(s"WHEN expressions should all be boolean type") } else if (!valueTypesEqual) { - Some("THEN and ELSE expressions should all be same type") + TypeCheckResult.fail("THEN and ELSE expressions should all be same type") } else { - None + TypeCheckResult.success } } @@ -408,11 +411,11 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW override def children: Seq[Expression] = key +: branches - override def checkInputDataTypes: Option[String] = { + override def checkInputDataTypes: TypeCheckResult = { if (!valueTypesEqual) { - Some("THEN and ELSE expressions should all be same type") + TypeCheckResult.fail("THEN and ELSE expressions should all be same type") } else { - None + TypeCheckResult.success } } From c71d02cedd0dd60c24b1f5b8a6ab5eb22ba27477 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 28 May 2015 17:26:28 +0800 Subject: [PATCH 04/17] fix hive tests --- .../catalyst/analysis/TypeCheckResult.scala | 2 +- .../sql/catalyst/expressions/predicates.scala | 35 +++++++++---------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala index 1020fc400187..165bbabd06d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala @@ -27,4 +27,4 @@ class TypeCheckResult(val errorMessage: String) extends AnyVal { object TypeCheckResult { val success: TypeCheckResult = new TypeCheckResult(null) def fail(msg: String): TypeCheckResult = new TypeCheckResult(msg) -} \ No newline at end of file +} diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 435946732ec4..2f4f44abf73a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -172,17 +172,6 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => - override def checkInputDataTypes: TypeCheckResult = { - if (left.dataType != right.dataType) { - TypeCheckResult.fail( - s"differing types in BinaryComparisons -- ${left.dataType}, ${right.dataType}") - } else { - checkTypesInternal(left.dataType) - } - } - - protected def checkTypesInternal(t: DataType): TypeCheckResult = TypeCheckResult.success - override def eval(input: Row): Any = { val evalE1 = left.eval(input) if(evalE1 == null) { @@ -231,8 +220,10 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case class LessThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<" - override protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForOrderingExpr(t)) { + override def checkInputDataTypes: TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.fail("types do not match -- ${left.dataType} != ${right.dataType}") + } else if (TypeUtils.validForOrderingExpr(left.dataType)) { TypeCheckResult.success } else { TypeCheckResult.fail("todo") @@ -247,8 +238,10 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<=" - override protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForOrderingExpr(t)) { + override def checkInputDataTypes: TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.fail("types do not match -- ${left.dataType} != ${right.dataType}") + } else if (TypeUtils.validForOrderingExpr(left.dataType)) { TypeCheckResult.success } else { TypeCheckResult.fail("todo") @@ -263,8 +256,10 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">" - override protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForOrderingExpr(t)) { + override def checkInputDataTypes: TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.fail("types do not match -- ${left.dataType} != ${right.dataType}") + } else if (TypeUtils.validForOrderingExpr(left.dataType)) { TypeCheckResult.success } else { TypeCheckResult.fail("todo") @@ -279,8 +274,10 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">=" - override protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForOrderingExpr(t)) { + override def checkInputDataTypes: TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.fail("types do not match -- ${left.dataType} != ${right.dataType}") + } else if (TypeUtils.validForOrderingExpr(left.dataType)) { TypeCheckResult.success } else { 
TypeCheckResult.fail("todo") From 69ca3feec42d54b3845c49d95b9890a0a40c4327 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 May 2015 13:14:55 +0800 Subject: [PATCH 05/17] add error message and tests --- .../catalyst/analysis/HiveTypeCoercion.scala | 9 +- .../sql/catalyst/expressions/Expression.scala | 5 +- .../sql/catalyst/expressions/arithmetic.scala | 120 +++++------------- .../expressions/mathfuncs/unary.scala | 1 + .../sql/catalyst/expressions/predicates.scala | 79 ++++-------- .../spark/sql/catalyst/util/DateUtils.scala | 2 +- .../spark/sql/catalyst/util/TypeUtils.scala | 32 ++++- .../ExpressionTypeCheckingSuite.scala | 117 +++++++++++++++++ 8 files changed, 210 insertions(+), 155 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index ba39a0ef2e2b..9d88a9bab789 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -619,18 +619,13 @@ trait HiveTypeCoercion { */ object Division extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + // Skip nodes who's children have not been resolved yet or input types do not match. + case e if !e.childrenResolved || e.checkInputDataTypes().hasError => e // Decimal and Double remain the same case d: Divide if d.resolved && d.dataType == DoubleType => d case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d - case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] => - Divide(l, Cast(r, DecimalType.Unlimited)) - case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] => - Divide(Cast(l, DecimalType.Unlimited), r) - case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 65c5212d407d..93e3ffb0c162 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -88,9 +88,10 @@ abstract class Expression extends TreeNode[Expression] { } /** - * todo + * Check the input data types, returns `TypeCheckResult.success` if it's valid, + * or return a `TypeCheckResult` with an error message if invalid. 
*/ - def checkInputDataTypes: TypeCheckResult = TypeCheckResult.success + def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success } abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4e02c7a28e63..7282a877f053 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -28,14 +28,6 @@ abstract class UnaryArithmetic extends UnaryExpression { override def nullable: Boolean = child.nullable override def dataType: DataType = child.dataType - override def checkInputDataTypes: TypeCheckResult = { - if (TypeUtils.validForNumericExpr(child.dataType)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } - override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { @@ -52,6 +44,9 @@ abstract class UnaryArithmetic extends UnaryExpression { case class UnaryMinus(child: Expression) extends UnaryArithmetic { override def toString: String = s"-$child" + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "operator -") + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def evalInternal(evalE: Any) = numeric.negate(evalE) @@ -62,6 +57,9 @@ case class Sqrt(child: Expression) extends UnaryArithmetic { override def nullable: Boolean = true override def toString: String = s"SQRT($child)" + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sqrt") + private lazy val numeric = TypeUtils.getNumeric(child.dataType) protected override def evalInternal(evalE: Any) = { @@ -77,6 +75,9 @@ case class Sqrt(child: Expression) extends UnaryArithmetic { case class Abs(child: Expression) extends UnaryArithmetic { override def toString: String = s"Abs($child)" + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function abs") + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def evalInternal(evalE: Any) = numeric.abs(evalE) @@ -87,10 +88,10 @@ abstract class BinaryArithmetic extends BinaryExpression { override def dataType: DataType = left.dataType - override def checkInputDataTypes: TypeCheckResult = { + override def checkInputDataTypes(): TypeCheckResult = { if (left.dataType != right.dataType) { TypeCheckResult.fail( - s"differing types in BinaryArithmetics -- ${left.dataType}, ${right.dataType}") + s"differing types in BinaryArithmetic, ${left.dataType} != ${right.dataType}") } else { checkTypesInternal(dataType) } @@ -123,13 +124,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { // for `Add` in `HiveTypeCoercion` override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForNumericExpr(t)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) private lazy val numeric = TypeUtils.getNumeric(dataType) @@ -143,13 +139,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti // for `Subtract` in 
`HiveTypeCoercion` override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForNumericExpr(t)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) private lazy val numeric = TypeUtils.getNumeric(dataType) @@ -163,13 +154,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti // for `Multiply` in `HiveTypeCoercion` override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForNumericExpr(t)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) private lazy val numeric = TypeUtils.getNumeric(dataType) @@ -184,13 +170,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic // for `Divide` in `HiveTypeCoercion` override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForNumericExpr(t)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div @@ -220,13 +201,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet // for `Remainder` in `HiveTypeCoercion` override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForNumericExpr(t)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] @@ -254,13 +230,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "&" - protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForBitwiseExpr(t)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) private lazy val and: (Any, Any) => Any = dataType match { case ByteType => @@ -282,13 +253,8 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "|" - protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForBitwiseExpr(t)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) private lazy val or: (Any, Any) => Any = dataType match { case ByteType => @@ -310,13 +276,8 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseXor(left: Expression, right: Expression) extends 
BinaryArithmetic { override def symbol: String = "^" - protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForBitwiseExpr(t)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) private lazy val xor: (Any, Any) => Any = dataType match { case ByteType => @@ -338,13 +299,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme case class BitwiseNot(child: Expression) extends UnaryArithmetic { override def toString: String = s"~$child" - override def checkInputDataTypes: TypeCheckResult = { - if (TypeUtils.validForBitwiseExpr(dataType)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") private lazy val not: (Any) => Any = dataType match { case ByteType => @@ -363,13 +319,8 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic { case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForOrderingExpr(t)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function maxOf") private lazy val ordering = TypeUtils.getOrdering(dataType) @@ -395,13 +346,8 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - protected def checkTypesInternal(t: DataType) = { - if (TypeUtils.validForOrderingExpr(t)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function minOf") private lazy val ordering = TypeUtils.getOrdering(dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index ff235f0ab58e..41b422346a02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -31,6 +31,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"$name($child)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2f4f44abf73a..1540e92abf83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import 
org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.types.{DecimalType, BinaryType, BooleanType, DataType}
+import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType}
 
 object InterpretedPredicate {
   def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
@@ -172,9 +172,18 @@ case class Or(left: Expression, right: Expression)
 abstract class BinaryComparison extends BinaryExpression with Predicate {
   self: Product =>
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (left.dataType != right.dataType) {
+      TypeCheckResult.fail(
+        s"differing types in BinaryComparison, ${left.dataType} != ${right.dataType}")
+    } else {
+      TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
+    }
+  }
+
   override def eval(input: Row): Any = {
     val evalE1 = left.eval(input)
-    if(evalE1 == null) {
+    if (evalE1 == null) {
       null
     } else {
       val evalE2 = right.eval(input)
@@ -187,12 +196,15 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
   }
 
   protected def evalInternal(evalE1: Any, evalE2: Any): Any =
-    sys.error(s"BinaryArithmetics must either override eval or evalInternal")
+    sys.error(s"BinaryComparisons must either override eval or evalInternal")
 }
 
 case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "="
 
+  // EqualTo doesn't need two equal orderable types
+  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
+
   protected override def evalInternal(l: Any, r: Any) = {
     if (left.dataType != BinaryType) l == r
     else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
@@ -201,9 +213,11 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
 
 case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "<=>"
-
   override def nullable: Boolean = false
 
+  // EqualNullSafe doesn't need two equal orderable types
+  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
+
   override def eval(input: Row): Any = {
     val l = left.eval(input)
     val r = right.eval(input)
@@ -220,16 +234,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
 
 case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "<"
 
-  override def checkInputDataTypes: TypeCheckResult = {
-    if (left.dataType != right.dataType) {
-      TypeCheckResult.fail("types do not match -- ${left.dataType} != ${right.dataType}")
-    } else if (TypeUtils.validForOrderingExpr(left.dataType)) {
-      TypeCheckResult.success
-    } else {
-      TypeCheckResult.fail("todo")
-    }
-  }
-
   private lazy val ordering = TypeUtils.getOrdering(left.dataType)
 
   protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2)
@@ -238,16 +242,6 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
 
 case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "<="
 
-  override def checkInputDataTypes: TypeCheckResult = {
-    if (left.dataType != right.dataType) {
-      TypeCheckResult.fail("types do not match -- ${left.dataType} != ${right.dataType}")
-    } else if (TypeUtils.validForOrderingExpr(left.dataType)) {
-      TypeCheckResult.success
-    } else {
-      TypeCheckResult.fail("todo")
-    }
-  }
-
   private lazy val ordering = TypeUtils.getOrdering(left.dataType)
 
   protected override def evalInternal(evalE1: Any, evalE2: Any) = 
ordering.lteq(evalE1, evalE2) @@ -256,16 +250,6 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">" - override def checkInputDataTypes: TypeCheckResult = { - if (left.dataType != right.dataType) { - TypeCheckResult.fail("types do not match -- ${left.dataType} != ${right.dataType}") - } else if (TypeUtils.validForOrderingExpr(left.dataType)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } - private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2) @@ -274,16 +258,6 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">=" - override def checkInputDataTypes: TypeCheckResult = { - if (left.dataType != right.dataType) { - TypeCheckResult.fail("types do not match -- ${left.dataType} != ${right.dataType}") - } else if (TypeUtils.validForOrderingExpr(left.dataType)) { - TypeCheckResult.success - } else { - TypeCheckResult.fail("todo") - } - } - private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2) @@ -295,10 +269,13 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable - override def checkInputDataTypes: TypeCheckResult = { - if (trueValue.dataType != falseValue.dataType) { + override def checkInputDataTypes(): TypeCheckResult = { + if (predicate.dataType != BooleanType) { + TypeCheckResult.fail( + s"type of predicate expression in If should be boolean, not ${predicate.dataType}") + } else if (trueValue.dataType != falseValue.dataType) { TypeCheckResult.fail( - s"differing types in If, ${trueValue.dataType}, ${falseValue.dataType}") + s"differing types in If, ${trueValue.dataType} != ${falseValue.dataType}") } else { TypeCheckResult.success } @@ -357,11 +334,11 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { override def children: Seq[Expression] = branches - override def checkInputDataTypes: TypeCheckResult = { + override def checkInputDataTypes(): TypeCheckResult = { if (!whenList.forall(_.dataType == BooleanType)) { TypeCheckResult.fail(s"WHEN expressions should all be boolean type") } else if (!valueTypesEqual) { - TypeCheckResult.fail("THEN and ELSE expressions should all be same type") + TypeCheckResult.fail("THEN and ELSE expressions should all be same type or coercible to a common type") } else { TypeCheckResult.success } @@ -408,9 +385,9 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW override def children: Seq[Expression] = key +: branches - override def checkInputDataTypes: TypeCheckResult = { + override def checkInputDataTypes(): TypeCheckResult = { if (!valueTypesEqual) { - TypeCheckResult.fail("THEN and ELSE expressions should all be same type") + TypeCheckResult.fail("THEN and ELSE expressions should all be same type or coercible to a common type") } else { TypeCheckResult.success } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala index 3f92be4a55d7..ad649acf536f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala @@ -24,7 +24,7 @@ import java.util.{Calendar, TimeZone} import org.apache.spark.sql.catalyst.expressions.Cast /** - * helper function to convert between Int value of days since 1970-01-01 and java.sql.Date + * Helper function to convert between Int value of days since 1970-01-01 and java.sql.Date */ object DateUtils { private val MILLIS_PER_DAY = 86400000 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index d8c1ee4864d3..81efb8b21e7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -17,18 +17,36 @@ package org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types._ /** - * Helper function to check valid data types + * Helper function to check for valid data types */ object TypeUtils { - - def validForNumericExpr(t: DataType): Boolean = t.isInstanceOf[NumericType] || t == NullType - - def validForBitwiseExpr(t: DataType): Boolean = t.isInstanceOf[IntegralType] || t == NullType - - def validForOrderingExpr(t: DataType): Boolean = t.isInstanceOf[AtomicType] || t == NullType + def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[NumericType] || t == NullType) { + TypeCheckResult.success + } else { + TypeCheckResult.fail(s"$caller need numeric type(int, long, double, etc.), not $t") + } + } + + def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[IntegralType] || t == NullType) { + TypeCheckResult.success + } else { + TypeCheckResult.fail(s"$caller need integral type(short, int, long, etc.), not $t") + } + } + + def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[AtomicType] || t == NullType) { + TypeCheckResult.success + } else { + TypeCheckResult.fail(s"$caller need atomic type(binary, boolean, numeric, etc), not $t") + } + } def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala new file mode 100644 index 000000000000..d8f3e4e4e254 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{BooleanType, StringType} +import org.scalatest.FunSuite + + +class ExpressionTypeCheckingSuite extends FunSuite { + + val testRelation = LocalRelation('a.int, 'b.string, 'c.boolean, 'd.array(StringType)) + + def checkError(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + checkAnalysis(expr) + } + assert(e.getMessage.contains(s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) + assert(e.getMessage.contains(errorMessage)) + } + + def checkAnalysis(expr: Expression): Unit = { + val analyzed = testRelation.select(expr.as("_c")).analyze + SimpleAnalyzer.checkAnalysis(analyzed) + } + + test("check types for unary arithmetic") { + checkError(UnaryMinus('b), "operator - need numeric type") + checkAnalysis(Sqrt('b)) // We will cast String to Double for sqrt + checkError(Sqrt('c), "function sqrt need numeric type") + checkError(Abs('b), "function abs need numeric type") + checkError(BitwiseNot('b), "operator ~ need integral type") + } + + test("check types for binary arithmetic") { + // We will cast String to Double for binary arithmetic + checkAnalysis(Add('a, 'b)) + checkAnalysis(Subtract('a, 'b)) + checkAnalysis(Multiply('a, 'b)) + checkAnalysis(Divide('a, 'b)) + checkAnalysis(Remainder('a, 'b)) + //checkAnalysis(BitwiseAnd('a, 'b)) + + val msg = "differing types in BinaryArithmetic, IntegerType != BooleanType" + checkError(Add('a, 'c), msg) + checkError(Subtract('a, 'c), msg) + checkError(Multiply('a, 'c), msg) + checkError(Divide('a, 'c), msg) + checkError(Remainder('a, 'c), msg) + checkError(BitwiseAnd('a, 'c), msg) + checkError(BitwiseOr('a, 'c), msg) + checkError(BitwiseXor('a, 'c), msg) + checkError(MaxOf('a, 'c), msg) + checkError(MinOf('a, 'c), msg) + + checkError(Add('c, 'c), "operator + need numeric type") + checkError(Subtract('c, 'c), "operator - need numeric type") + checkError(Multiply('c, 'c), "operator * need numeric type") + checkError(Divide('c, 'c), "operator / need numeric type") + checkError(Remainder('c, 'c), "operator % need numeric type") + + checkError(BitwiseAnd('c, 'c), "operator & need integral type") + checkError(BitwiseOr('c, 'c), "operator | need integral type") + checkError(BitwiseXor('c, 'c), "operator ^ need integral type") + + checkError(MaxOf('d, 'd), "function maxOf need atomic type") + checkError(MinOf('d, 'd), "function minOf need atomic type") + } + + test("check types for predicates") { + // EqualTo don't have type constraint + checkAnalysis(EqualTo('a, 'c)) + checkAnalysis(EqualNullSafe('a, 'c)) + + // We will cast String to Double for binary comparison + checkAnalysis(LessThan('a, 'b)) + checkAnalysis(LessThanOrEqual('a, 'b)) + checkAnalysis(GreaterThan('a, 'b)) + checkAnalysis(GreaterThanOrEqual('a, 'b)) + + 
val msg = "differing types in BinaryComparison, IntegerType != BooleanType" + checkError(LessThan('a, 'c), msg) + checkError(LessThanOrEqual('a, 'c), msg) + checkError(GreaterThan('a, 'c), msg) + checkError(GreaterThanOrEqual('a, 'c), msg) + + checkError(LessThan('d, 'd), "operator < need atomic type") + checkError(LessThanOrEqual('d, 'd), "operator <= need atomic type") + checkError(GreaterThan('d, 'd), "operator > need atomic type") + checkError(GreaterThanOrEqual('d, 'd), "operator >= need atomic type") + + checkError(If('a, 'a, 'a), "type of predicate expression in If should be boolean") + checkError(If('c, 'a, 'b), "differing types in If, IntegerType != StringType") + + // Will write tests for CaseWhen later, + // as the error reporting of it is not handle by the new interface for now + } +} From 1524ff67e0bf3c0dc2e58c063dd71853bd7b0e81 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 May 2015 13:49:02 +0800 Subject: [PATCH 06/17] fix style --- .../apache/spark/sql/catalyst/expressions/predicates.scala | 6 ++++-- .../catalyst/expressions/ExpressionTypeCheckingSuite.scala | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1540e92abf83..5f29036083e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -338,7 +338,8 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { if (!whenList.forall(_.dataType == BooleanType)) { TypeCheckResult.fail(s"WHEN expressions should all be boolean type") } else if (!valueTypesEqual) { - TypeCheckResult.fail("THEN and ELSE expressions should all be same type or coercible to a common type") + TypeCheckResult.fail( + "THEN and ELSE expressions should all be same type or coercible to a common type") } else { TypeCheckResult.success } @@ -387,7 +388,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW override def checkInputDataTypes(): TypeCheckResult = { if (!valueTypesEqual) { - TypeCheckResult.fail("THEN and ELSE expressions should all be same type or coercible to a common type") + TypeCheckResult.fail( + "THEN and ELSE expressions should all be same type or coercible to a common type") } else { TypeCheckResult.success } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index d8f3e4e4e254..9fdefc5301b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -34,7 +34,8 @@ class ExpressionTypeCheckingSuite extends FunSuite { val e = intercept[AnalysisException] { checkAnalysis(expr) } - assert(e.getMessage.contains(s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) + assert(e.getMessage.contains( + s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) assert(e.getMessage.contains(errorMessage)) } @@ -58,7 +59,7 @@ class ExpressionTypeCheckingSuite extends FunSuite { checkAnalysis(Multiply('a, 'b)) checkAnalysis(Divide('a, 'b)) checkAnalysis(Remainder('a, 'b)) - 
//checkAnalysis(BitwiseAnd('a, 'b)) + // checkAnalysis(BitwiseAnd('a, 'b)) val msg = "differing types in BinaryArithmetic, IntegerType != BooleanType" checkError(Add('a, 'c), msg) From e0a3628a1052544beb395b26c1f5764836f9461d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 May 2015 13:52:46 +0800 Subject: [PATCH 07/17] improve error message --- .../spark/sql/catalyst/util/TypeUtils.scala | 6 ++-- .../ExpressionTypeCheckingSuite.scala | 36 +++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 81efb8b21e7f..26df4fbfcf31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -28,7 +28,7 @@ object TypeUtils { if (t.isInstanceOf[NumericType] || t == NullType) { TypeCheckResult.success } else { - TypeCheckResult.fail(s"$caller need numeric type(int, long, double, etc.), not $t") + TypeCheckResult.fail(s"$caller accepts numeric types, not $t") } } @@ -36,7 +36,7 @@ object TypeUtils { if (t.isInstanceOf[IntegralType] || t == NullType) { TypeCheckResult.success } else { - TypeCheckResult.fail(s"$caller need integral type(short, int, long, etc.), not $t") + TypeCheckResult.fail(s"$caller accepts integral types, not $t") } } @@ -44,7 +44,7 @@ object TypeUtils { if (t.isInstanceOf[AtomicType] || t == NullType) { TypeCheckResult.success } else { - TypeCheckResult.fail(s"$caller need atomic type(binary, boolean, numeric, etc), not $t") + TypeCheckResult.fail(s"$caller accepts non-complex types, not $t") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index 9fdefc5301b8..c241d05063ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -45,11 +45,11 @@ class ExpressionTypeCheckingSuite extends FunSuite { } test("check types for unary arithmetic") { - checkError(UnaryMinus('b), "operator - need numeric type") + checkError(UnaryMinus('b), "operator - accepts numeric type") checkAnalysis(Sqrt('b)) // We will cast String to Double for sqrt - checkError(Sqrt('c), "function sqrt need numeric type") - checkError(Abs('b), "function abs need numeric type") - checkError(BitwiseNot('b), "operator ~ need integral type") + checkError(Sqrt('c), "function sqrt accepts numeric type") + checkError(Abs('b), "function abs accepts numeric type") + checkError(BitwiseNot('b), "operator ~ accepts integral type") } test("check types for binary arithmetic") { @@ -73,18 +73,18 @@ class ExpressionTypeCheckingSuite extends FunSuite { checkError(MaxOf('a, 'c), msg) checkError(MinOf('a, 'c), msg) - checkError(Add('c, 'c), "operator + need numeric type") - checkError(Subtract('c, 'c), "operator - need numeric type") - checkError(Multiply('c, 'c), "operator * need numeric type") - checkError(Divide('c, 'c), "operator / need numeric type") - checkError(Remainder('c, 'c), "operator % need numeric type") + checkError(Add('c, 'c), "operator + accepts numeric type") + checkError(Subtract('c, 'c), "operator - accepts numeric type") + checkError(Multiply('c, 'c), 
"operator * accepts numeric type") + checkError(Divide('c, 'c), "operator / accepts numeric type") + checkError(Remainder('c, 'c), "operator % accepts numeric type") - checkError(BitwiseAnd('c, 'c), "operator & need integral type") - checkError(BitwiseOr('c, 'c), "operator | need integral type") - checkError(BitwiseXor('c, 'c), "operator ^ need integral type") + checkError(BitwiseAnd('c, 'c), "operator & accepts integral type") + checkError(BitwiseOr('c, 'c), "operator | accepts integral type") + checkError(BitwiseXor('c, 'c), "operator ^ accepts integral type") - checkError(MaxOf('d, 'd), "function maxOf need atomic type") - checkError(MinOf('d, 'd), "function minOf need atomic type") + checkError(MaxOf('d, 'd), "function maxOf accepts non-complex type") + checkError(MinOf('d, 'd), "function minOf accepts non-complex type") } test("check types for predicates") { @@ -104,10 +104,10 @@ class ExpressionTypeCheckingSuite extends FunSuite { checkError(GreaterThan('a, 'c), msg) checkError(GreaterThanOrEqual('a, 'c), msg) - checkError(LessThan('d, 'd), "operator < need atomic type") - checkError(LessThanOrEqual('d, 'd), "operator <= need atomic type") - checkError(GreaterThan('d, 'd), "operator > need atomic type") - checkError(GreaterThanOrEqual('d, 'd), "operator >= need atomic type") + checkError(LessThan('d, 'd), "operator < accepts non-complex type") + checkError(LessThanOrEqual('d, 'd), "operator <= accepts non-complex type") + checkError(GreaterThan('d, 'd), "operator > accepts non-complex type") + checkError(GreaterThanOrEqual('d, 'd), "operator >= accepts non-complex type") checkError(If('a, 'a, 'a), "type of predicate expression in If should be boolean") checkError(If('c, 'a, 'b), "differing types in If, IntegerType != StringType") From 654d46abd2e5a988775edd0c50c395be63a4163f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 May 2015 14:13:55 +0800 Subject: [PATCH 08/17] improve tests --- .../sql/catalyst/expressions/arithmetic.scala | 2 +- .../sql/catalyst/expressions/predicates.scala | 2 +- .../ExpressionTypeCheckingSuite.scala | 126 +++++++++--------- 3 files changed, 67 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7282a877f053..72dc8cc86679 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -91,7 +91,7 @@ abstract class BinaryArithmetic extends BinaryExpression { override def checkInputDataTypes(): TypeCheckResult = { if (left.dataType != right.dataType) { TypeCheckResult.fail( - s"differing types in BinaryArithmetic, ${left.dataType} != ${right.dataType}") + s"differing types in ${this.getClass.getSimpleName}, ${left.dataType} != ${right.dataType}") } else { checkTypesInternal(dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5f29036083e3..874283c0f5a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -175,7 +175,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { override def checkInputDataTypes(): TypeCheckResult = { if 
(left.dataType != right.dataType) { TypeCheckResult.fail( - s"differing types in BinaryComparison, ${left.dataType} != ${right.dataType}") + s"differing types in ${this.getClass.getSimpleName}, ${left.dataType} != ${right.dataType}") } else { TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index c241d05063ef..c9481a5f9625 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -28,89 +28,93 @@ import org.scalatest.FunSuite class ExpressionTypeCheckingSuite extends FunSuite { - val testRelation = LocalRelation('a.int, 'b.string, 'c.boolean, 'd.array(StringType)) + val testRelation = LocalRelation( + 'intField.int, + 'stringField.string, + 'booleanField.boolean, + 'complexField.array(StringType)) - def checkError(expr: Expression, errorMessage: String): Unit = { + def assertError(expr: Expression, errorMessage: String): Unit = { val e = intercept[AnalysisException] { - checkAnalysis(expr) + assertSuccess(expr) } assert(e.getMessage.contains( s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) assert(e.getMessage.contains(errorMessage)) } - def checkAnalysis(expr: Expression): Unit = { - val analyzed = testRelation.select(expr.as("_c")).analyze + def assertSuccess(expr: Expression): Unit = { + val analyzed = testRelation.select(expr.as("c")).analyze SimpleAnalyzer.checkAnalysis(analyzed) } test("check types for unary arithmetic") { - checkError(UnaryMinus('b), "operator - accepts numeric type") - checkAnalysis(Sqrt('b)) // We will cast String to Double for sqrt - checkError(Sqrt('c), "function sqrt accepts numeric type") - checkError(Abs('b), "function abs accepts numeric type") - checkError(BitwiseNot('b), "operator ~ accepts integral type") + assertError(UnaryMinus('stringField), "operator - accepts numeric type") + assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt + assertError(Sqrt('booleanField), "function sqrt accepts numeric type") + assertError(Abs('stringField), "function abs accepts numeric type") + assertError(BitwiseNot('stringField), "operator ~ accepts integral type") } test("check types for binary arithmetic") { // We will cast String to Double for binary arithmetic - checkAnalysis(Add('a, 'b)) - checkAnalysis(Subtract('a, 'b)) - checkAnalysis(Multiply('a, 'b)) - checkAnalysis(Divide('a, 'b)) - checkAnalysis(Remainder('a, 'b)) - // checkAnalysis(BitwiseAnd('a, 'b)) - - val msg = "differing types in BinaryArithmetic, IntegerType != BooleanType" - checkError(Add('a, 'c), msg) - checkError(Subtract('a, 'c), msg) - checkError(Multiply('a, 'c), msg) - checkError(Divide('a, 'c), msg) - checkError(Remainder('a, 'c), msg) - checkError(BitwiseAnd('a, 'c), msg) - checkError(BitwiseOr('a, 'c), msg) - checkError(BitwiseXor('a, 'c), msg) - checkError(MaxOf('a, 'c), msg) - checkError(MinOf('a, 'c), msg) - - checkError(Add('c, 'c), "operator + accepts numeric type") - checkError(Subtract('c, 'c), "operator - accepts numeric type") - checkError(Multiply('c, 'c), "operator * accepts numeric type") - checkError(Divide('c, 'c), "operator / accepts numeric type") - checkError(Remainder('c, 'c), "operator % accepts numeric type") - - 
checkError(BitwiseAnd('c, 'c), "operator & accepts integral type") - checkError(BitwiseOr('c, 'c), "operator | accepts integral type") - checkError(BitwiseXor('c, 'c), "operator ^ accepts integral type") - - checkError(MaxOf('d, 'd), "function maxOf accepts non-complex type") - checkError(MinOf('d, 'd), "function minOf accepts non-complex type") + assertSuccess(Add('intField, 'stringField)) + assertSuccess(Subtract('intField, 'stringField)) + assertSuccess(Multiply('intField, 'stringField)) + assertSuccess(Divide('intField, 'stringField)) + assertSuccess(Remainder('intField, 'stringField)) + // checkAnalysis(BitwiseAnd('intField, 'stringField)) + + def msg(caller: String) = s"differing types in $caller, IntegerType != BooleanType" + assertError(Add('intField, 'booleanField), msg("Add")) + assertError(Subtract('intField, 'booleanField), msg("Subtract")) + assertError(Multiply('intField, 'booleanField), msg("Multiply")) + assertError(Divide('intField, 'booleanField), msg("Divide")) + assertError(Remainder('intField, 'booleanField), msg("Remainder")) + assertError(BitwiseAnd('intField, 'booleanField), msg("BitwiseAnd")) + assertError(BitwiseOr('intField, 'booleanField), msg("BitwiseOr")) + assertError(BitwiseXor('intField, 'booleanField), msg("BitwiseXor")) + assertError(MaxOf('intField, 'booleanField), msg("MaxOf")) + assertError(MinOf('intField, 'booleanField), msg("MinOf")) + + assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") + assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") + assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") + assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") + + assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + + assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") + assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") } test("check types for predicates") { // EqualTo don't have type constraint - checkAnalysis(EqualTo('a, 'c)) - checkAnalysis(EqualNullSafe('a, 'c)) + assertSuccess(EqualTo('intField, 'booleanField)) + assertSuccess(EqualNullSafe('intField, 'booleanField)) // We will cast String to Double for binary comparison - checkAnalysis(LessThan('a, 'b)) - checkAnalysis(LessThanOrEqual('a, 'b)) - checkAnalysis(GreaterThan('a, 'b)) - checkAnalysis(GreaterThanOrEqual('a, 'b)) - - val msg = "differing types in BinaryComparison, IntegerType != BooleanType" - checkError(LessThan('a, 'c), msg) - checkError(LessThanOrEqual('a, 'c), msg) - checkError(GreaterThan('a, 'c), msg) - checkError(GreaterThanOrEqual('a, 'c), msg) - - checkError(LessThan('d, 'd), "operator < accepts non-complex type") - checkError(LessThanOrEqual('d, 'd), "operator <= accepts non-complex type") - checkError(GreaterThan('d, 'd), "operator > accepts non-complex type") - checkError(GreaterThanOrEqual('d, 'd), "operator >= accepts non-complex type") - - checkError(If('a, 'a, 'a), "type of predicate expression in If should be boolean") - checkError(If('c, 'a, 'b), "differing types in If, IntegerType != StringType") + assertSuccess(LessThan('intField, 'stringField)) + 
assertSuccess(LessThanOrEqual('intField, 'stringField)) + assertSuccess(GreaterThan('intField, 'stringField)) + assertSuccess(GreaterThanOrEqual('intField, 'stringField)) + + def msg(caller: String) = s"differing types in $caller, IntegerType != BooleanType" + assertError(LessThan('intField, 'booleanField), msg("LessThan")) + assertError(LessThanOrEqual('intField, 'booleanField), msg("LessThanOrEqual")) + assertError(GreaterThan('intField, 'booleanField), msg("GreaterThan")) + assertError(GreaterThanOrEqual('intField, 'booleanField), msg("GreaterThanOrEqual")) + + assertError(LessThan('complexField, 'complexField), "operator < accepts non-complex type") + assertError(LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") + assertError(GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") + assertError(GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + + assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") + assertError(If('booleanField, 'intField, 'stringField), "differing types in If, IntegerType != StringType") // Will write tests for CaseWhen later, // as the error reporting of it is not handle by the new interface for now From 3affbd8746aaf3a38aca0a9b65198525b973aa64 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 May 2015 15:31:37 +0800 Subject: [PATCH 09/17] more fixes --- .../catalyst/analysis/TypeCheckResult.scala | 4 +- .../sql/catalyst/expressions/arithmetic.scala | 3 +- .../sql/catalyst/expressions/predicates.scala | 5 +- .../ExpressionTypeCheckingSuite.scala | 57 +++++++++++-------- 4 files changed, 41 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala index 165bbabd06d1..d2474b0f3ad3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.analysis /** - * todo + * Represents the result of `Expression.checkInputDataTypes`. + * We will throw `AnalysisException` in `CheckAnalysis` if error message is not null. 
+ * */ class TypeCheckResult(val errorMessage: String) extends AnyVal { def hasError: Boolean = errorMessage != null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 72dc8cc86679..641baf10ad8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -91,7 +91,8 @@ abstract class BinaryArithmetic extends BinaryExpression { override def checkInputDataTypes(): TypeCheckResult = { if (left.dataType != right.dataType) { TypeCheckResult.fail( - s"differing types in ${this.getClass.getSimpleName}, ${left.dataType} != ${right.dataType}") + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") } else { checkTypesInternal(dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 874283c0f5a3..2b03d802c9e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -175,7 +175,8 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { override def checkInputDataTypes(): TypeCheckResult = { if (left.dataType != right.dataType) { TypeCheckResult.fail( - s"differing types in ${this.getClass.getSimpleName}, ${left.dataType} != ${right.dataType}") + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") } else { TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) } @@ -275,7 +276,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s"type of predicate expression in If should be boolean, not ${predicate.dataType}") } else if (trueValue.dataType != falseValue.dataType) { TypeCheckResult.fail( - s"differing types in If, ${trueValue.dataType} != ${falseValue.dataType}") + s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") } else { TypeCheckResult.success } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index c9481a5f9625..fb0fcf2ba2c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -48,6 +48,11 @@ class ExpressionTypeCheckingSuite extends FunSuite { SimpleAnalyzer.checkAnalysis(analyzed) } + def assertErrorForDifferingTypes(expr: Expression): Unit = { + assertError(expr, + s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") + } + test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "operator - accepts numeric type") assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt @@ -65,17 +70,16 @@ class ExpressionTypeCheckingSuite extends FunSuite { assertSuccess(Remainder('intField, 'stringField)) // checkAnalysis(BitwiseAnd('intField, 'stringField)) - def msg(caller: String) = s"differing types in $caller, IntegerType != BooleanType" 
- assertError(Add('intField, 'booleanField), msg("Add")) - assertError(Subtract('intField, 'booleanField), msg("Subtract")) - assertError(Multiply('intField, 'booleanField), msg("Multiply")) - assertError(Divide('intField, 'booleanField), msg("Divide")) - assertError(Remainder('intField, 'booleanField), msg("Remainder")) - assertError(BitwiseAnd('intField, 'booleanField), msg("BitwiseAnd")) - assertError(BitwiseOr('intField, 'booleanField), msg("BitwiseOr")) - assertError(BitwiseXor('intField, 'booleanField), msg("BitwiseXor")) - assertError(MaxOf('intField, 'booleanField), msg("MaxOf")) - assertError(MinOf('intField, 'booleanField), msg("MinOf")) + assertErrorForDifferingTypes(Add('intField, 'booleanField)) + assertErrorForDifferingTypes(Subtract('intField, 'booleanField)) + assertErrorForDifferingTypes(Multiply('intField, 'booleanField)) + assertErrorForDifferingTypes(Divide('intField, 'booleanField)) + assertErrorForDifferingTypes(Remainder('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseAnd('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseOr('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseXor('intField, 'booleanField)) + assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) + assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") @@ -102,19 +106,24 @@ class ExpressionTypeCheckingSuite extends FunSuite { assertSuccess(GreaterThan('intField, 'stringField)) assertSuccess(GreaterThanOrEqual('intField, 'stringField)) - def msg(caller: String) = s"differing types in $caller, IntegerType != BooleanType" - assertError(LessThan('intField, 'booleanField), msg("LessThan")) - assertError(LessThanOrEqual('intField, 'booleanField), msg("LessThanOrEqual")) - assertError(GreaterThan('intField, 'booleanField), msg("GreaterThan")) - assertError(GreaterThanOrEqual('intField, 'booleanField), msg("GreaterThanOrEqual")) - - assertError(LessThan('complexField, 'complexField), "operator < accepts non-complex type") - assertError(LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") - assertError(GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") - assertError(GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") - - assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") - assertError(If('booleanField, 'intField, 'stringField), "differing types in If, IntegerType != StringType") + assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) + assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) + + assertError( + LessThan('complexField, 'complexField), "operator < accepts non-complex type") + assertError( + LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") + assertError( + GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") + assertError( + GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + + assertError( + If('intField, 'stringField, 'stringField), + "type of predicate expression in If should be boolean") + assertErrorForDifferingTypes(If('booleanField, 
'intField, 'booleanField)) // Will write tests for CaseWhen later, // as the error reporting of it is not handle by the new interface for now From 6eaadff74af8e7e36793cf520ab399d4b6f067b9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 May 2015 16:10:32 +0800 Subject: [PATCH 10/17] add equal type constraint to EqualTo --- .../sql/catalyst/expressions/predicates.scala | 22 ++++++++++++++----- .../ExpressionTypeCheckingSuite.scala | 8 +++---- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2b03d802c9e5..343d4ab00c07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -178,10 +178,12 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { s"differing types in ${this.getClass.getSimpleName} " + s"(${left.dataType} and ${right.dataType}).") } else { - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + checkTypesInternal(dataType) } } + protected def checkTypesInternal(t: DataType): TypeCheckResult + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { @@ -203,8 +205,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "=" - // EqualTo don't need 2 equal orderable types - override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.success protected override def evalInternal(l: Any, r: Any) = { if (left.dataType != BinaryType) l == r @@ -216,8 +217,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def symbol: String = "<=>" override def nullable: Boolean = false - // EqualNullSafe don't need 2 equal orderable types - override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.success override def eval(input: Row): Any = { val l = left.eval(input) @@ -235,6 +235,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case class LessThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<" + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2) @@ -243,6 +246,9 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<=" + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2) @@ -251,6 +257,9 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo case class GreaterThan(left: Expression, right: Expression) 
extends BinaryComparison { override def symbol: String = ">" + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2) @@ -259,6 +268,9 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">=" + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index fb0fcf2ba2c1..69be496a751a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -96,16 +96,16 @@ class ExpressionTypeCheckingSuite extends FunSuite { } test("check types for predicates") { - // EqualTo don't have type constraint - assertSuccess(EqualTo('intField, 'booleanField)) - assertSuccess(EqualNullSafe('intField, 'booleanField)) - // We will cast String to Double for binary comparison + assertSuccess(EqualTo('intField, 'stringField)) + assertSuccess(EqualNullSafe('intField, 'stringField)) assertSuccess(LessThan('intField, 'stringField)) assertSuccess(LessThanOrEqual('intField, 'stringField)) assertSuccess(GreaterThan('intField, 'stringField)) assertSuccess(GreaterThanOrEqual('intField, 'stringField)) + assertErrorForDifferingTypes(EqualTo('intField, 'booleanField)) + assertErrorForDifferingTypes(EqualNullSafe('intField, 'booleanField)) assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) From cffb67ca7baf16304c477fff241512f51b594b80 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 May 2015 18:09:32 +0800 Subject: [PATCH 11/17] to have resolved call the data type check function --- .../catalyst/analysis/HiveTypeCoercion.scala | 11 +++++----- .../sql/catalyst/expressions/Expression.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 20 ------------------- .../ExpressionTypeCheckingSuite.scala | 7 ++++--- 4 files changed, 11 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 9d88a9bab789..bb0ad2dc02f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -407,7 +407,7 @@ trait HiveTypeCoercion { Union(newLeft, newRight) // fix decimal precision for expressions - case q => q.transformExpressionsUp { + case q => q.transformExpressions { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e @@ -619,12 +619,13 @@ 
trait HiveTypeCoercion { */ object Division extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet or input types do not match. - case e if !e.childrenResolved || e.checkInputDataTypes().hasError => e + // Skip Divisions who has not been resolved yet, + // as this is an extra rule which should be applied at last. + case e if !e.resolved => e // Decimal and Double remain the same - case d: Divide if d.resolved && d.dataType == DoubleType => d - case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d + case d: Divide if d.dataType == DoubleType => d + case d: Divide if d.dataType.isInstanceOf[DecimalType] => d case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 93e3ffb0c162..6ab633f06764 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -49,7 +49,7 @@ abstract class Expression extends TreeNode[Expression] { * should override this if the resolution of this type of expression involves more than just * the resolution of its children. */ - lazy val resolved: Boolean = childrenResolved + lazy val resolved: Boolean = childrenResolved && !checkInputDataTypes().hasError /** * Returns the [[DataType]] of the result of evaluating this expression. It is diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 641baf10ad8f..0c2b7b4351da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -121,10 +121,6 @@ abstract class BinaryArithmetic extends BinaryExpression { case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" - // We will always cast fixed decimal to unlimited decimal - // for `Add` in `HiveTypeCoercion` - override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = TypeUtils.checkForNumericExpr(t, "operator " + symbol) @@ -136,10 +132,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" - // We will always cast fixed decimal to unlimited decimal - // for `Subtract` in `HiveTypeCoercion` - override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = TypeUtils.checkForNumericExpr(t, "operator " + symbol) @@ -151,10 +143,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" - // We will always cast fixed decimal to unlimited decimal - // for `Multiply` in `HiveTypeCoercion` - override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = TypeUtils.checkForNumericExpr(t, "operator " + symbol) @@ 
-167,10 +155,6 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic override def symbol: String = "/" override def nullable: Boolean = true - // We will always cast fixed decimal to unlimited decimal - // for `Divide` in `HiveTypeCoercion` - override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = TypeUtils.checkForNumericExpr(t, "operator " + symbol) @@ -198,10 +182,6 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def symbol: String = "%" override def nullable: Boolean = true - // We will always cast fixed decimal to unlimited decimal - // for `Remainder` in `HiveTypeCoercion` - override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = TypeUtils.checkForNumericExpr(t, "operator " + symbol) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index 69be496a751a..6f7c3dfb3d50 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{BooleanType, StringType} +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.StringType + import org.scalatest.FunSuite From 888302544212d61e5dcf691e2d4e4c92a880c731 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 May 2015 18:54:06 +0800 Subject: [PATCH 12/17] apply type check interface to CaseWhen --- .../catalyst/analysis/HiveTypeCoercion.scala | 41 ++++++++++--------- .../sql/catalyst/expressions/predicates.scala | 33 +++++++-------- .../ExpressionTypeCheckingSuite.scala | 12 +++++- 3 files changed, 48 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index bb0ad2dc02f4..aeb9e6bbf261 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -599,9 +599,8 @@ trait HiveTypeCoercion { // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => - val dt: Option[DataType] = Some(NullType) val types = es.map(_.dataType) - val rt = types.foldLeft(dt)((r, c) => r match { + val rt = types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case None => None case Some(d) => findTightestCommonType(d, c) }) @@ -635,28 +634,30 @@ trait HiveTypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. 
*/ object CaseWhenCoercion extends Rule[LogicalPlan] { + import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual => + case cw: CaseWhenLike if cw.childrenResolved && cw.checkInputDataTypes().hasError => logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") - val commonType = cw.valueTypes.reduce { (v1, v2) => - findTightestCommonType(v1, v2).getOrElse(sys.error( - s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) - } - val transformedBranches = cw.branches.sliding(2, 2).map { - case Seq(when, value) if value.dataType != commonType => - Seq(when, Cast(value, commonType)) - case Seq(elseVal) if elseVal.dataType != commonType => - Seq(Cast(elseVal, commonType)) - case s => s - }.reduce(_ ++ _) - cw match { - case _: CaseWhen => - CaseWhen(transformedBranches) - case CaseKeyWhen(key, _) => - CaseKeyWhen(key, transformedBranches) - } + cw.valueTypes.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case None => None + case Some(d) => findTightestCommonType(d, c) + }).map { commonType => + val transformedBranches = cw.branches.sliding(2, 2).map { + case Seq(when, value) if value.dataType != commonType => + Seq(when, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) + case s => s + }.reduce(_ ++ _) + cw match { + case _: CaseWhen => + CaseWhen(transformedBranches) + case CaseKeyWhen(key, _) => + CaseKeyWhen(key, transformedBranches) + } + }.getOrElse(cw) case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved => val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 343d4ab00c07..d2e8f1dcf3a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -325,7 +325,18 @@ trait CaseWhenLike extends Expression { def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 - override def dataType: DataType = valueTypes.head + override def checkInputDataTypes(): TypeCheckResult = { + if (valueTypes.distinct.size > 1) { + TypeCheckResult.fail( + "THEN and ELSE expressions should all be same type or coercible to a common type") + } else { + checkTypesInternal() + } + } + + protected def checkTypesInternal(): TypeCheckResult + + override def dataType: DataType = thenList.head.dataType override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. 
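
A note on the data layout these changes rely on: `CaseWhenLike.branches` is one flat sequence interleaving conditions and values, `Seq(when1, then1, when2, then2, ..., elseValue)`, where a trailing odd element acts as the ELSE value. That is why `CaseWhenCoercion` above walks the branches with `sliding(2, 2)`. A minimal self-contained sketch of that pairing, using plain strings as hypothetical stand-ins for the real `Expression` trees:

// Sketch only: the CASE WHEN branch layout assumed by CaseWhenLike and
// CaseWhenCoercion. Strings stand in for Expressions; an odd-length
// sequence means a trailing ELSE value.
object CaseWhenBranchesSketch extends App {
  val branches = Seq("cond1", "value1", "cond2", "value2", "elseValue")

  // sliding(2, 2) yields Seq(cond1, value1), Seq(cond2, value2), Seq(elseValue):
  // complete pairs are (WHEN, THEN); the singleton tail is the ELSE value.
  val coerced = branches.sliding(2, 2).map {
    case Seq(when, value) => Seq(when, s"cast($value)") // cast THEN values only
    case Seq(elseValue)   => Seq(s"cast($elseValue)")   // cast the ELSE value
  }.reduce(_ ++ _)

  println(coerced) // List(cond1, cast(value1), cond2, cast(value2), cast(elseValue))
}

This mirrors how the coercion rule rebuilds `branches` after settling on a common value type: conditions pass through untouched, while THEN and ELSE values get wrapped in casts.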
@@ -347,14 +358,11 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {

   override def children: Seq[Expression] = branches

-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (!whenList.forall(_.dataType == BooleanType)) {
-      TypeCheckResult.fail(s"WHEN expressions should all be boolean type")
-    } else if (!valueTypesEqual) {
-      TypeCheckResult.fail(
-        "THEN and ELSE expressions should all be same type or coercible to a common type")
-    } else {
+  override protected def checkTypesInternal(): TypeCheckResult = {
+    if (whenList.forall(_.dataType == BooleanType)) {
       TypeCheckResult.success
+    } else {
+      TypeCheckResult.fail(s"WHEN expressions in CaseWhen should all be boolean type")
     }
   }

@@ -399,14 +407,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW

   override def children: Seq[Expression] = key +: branches

-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (!valueTypesEqual) {
-      TypeCheckResult.fail(
-        "THEN and ELSE expressions should all be same type or coercible to a common type")
-    } else {
-      TypeCheckResult.success
-    }
-  }
+  override protected def checkTypesInternal(): TypeCheckResult = TypeCheckResult.success

   /** Written in imperative fashion for performance considerations. */
   override def eval(input: Row): Any = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
index 6f7c3dfb3d50..3cf67021fc93 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
@@ -126,7 +126,15 @@ class ExpressionTypeCheckingSuite extends FunSuite {
       "type of predicate expression in If should be boolean")
     assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))

-    // Will write tests for CaseWhen later,
-    // as the error reporting of it is not handle by the new interface for now
+    assertError(
+      CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)),
+      "THEN and ELSE expressions should all be same type or coercible to a common type")
+    assertError(
+      CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)),
+      "THEN and ELSE expressions should all be same type or coercible to a common type")
+    assertError(
+      CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
+      "WHEN expressions in CaseWhen should all be boolean type")
+
   }
 }

From 3bee1574b3d4f597d19d159de855ce90eacafa14 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 30 May 2015 13:49:01 +0800
Subject: [PATCH 13/17] add decimal type coercion rule for binary comparison

---
 .../catalyst/analysis/HiveTypeCoercion.scala  | 20 +++++--------------
 .../sql/catalyst/expressions/predicates.scala |  5 +++++
 2 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index aeb9e6bbf261..6ce582919e9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -323,6 +323,7 @@ trait HiveTypeCoercion {
    *   e1 union e2  max(s1, s2) + max(p1-s1, p2-s2)  max(s1, s2)
    *   sum(e1)      p1 + 10                          s1
    *   avg(e1)      p1 + 4                           s1 + 4
+   *   compare      max(p1, p2)                      max(s1, s2)
    *
    * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision.
    *
@@ -441,21 +442,10 @@ trait HiveTypeCoercion {
           DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
         )

-        case LessThan(e1 @ DecimalType.Expression(p1, s1),
-                      e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
-          LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
-
-        case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
-                             e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
-          LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
-
-        case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
-                         e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
-          GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
-
-        case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
-                                e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
-          GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
+        case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
+                                  e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+          val resultType = DecimalType(max(p1, p2), max(s1, s2))
+          b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))

         // Promote integers inside a binary expression with fixed-precision decimals to decimals,
         // and fixed-precision decimals in an expression with floats / doubles to doubles
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index d2e8f1dcf3a3..54a5ae9c3bb4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -202,6 +202,11 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
     sys.error(s"BinaryComparisons must either override eval or evalInternal")
 }

+object BinaryComparison {
+  def unapply(b: BinaryComparison): Option[(Expression, Expression)] =
+    Some((b.left, b.right))
+}
+
 case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "="

From 0808fd2e947c51d97287197e3ac1bf987f011383 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 30 May 2015 14:01:11 +0800
Subject: [PATCH 14/17] make constructor of TypeCheckResult private

---
 .../apache/spark/sql/catalyst/analysis/TypeCheckResult.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala
index d2474b0f3ad3..653015154fc1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.catalyst.analysis
 /**
  * Represents the result of `Expression.checkInputDataTypes`.
  * We will throw `AnalysisException` in `CheckAnalysis` if error message is not null.
+ * Use [[TypeCheckResult.success]] and [[TypeCheckResult.fail]] to instantiate this.
* */ -class TypeCheckResult(val errorMessage: String) extends AnyVal { +class TypeCheckResult private (val errorMessage: String) extends AnyVal { def hasError: Boolean = errorMessage != null } From 39929d9871c9fdb39eb4a6fee94df22b7e070358 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 30 May 2015 14:06:32 +0800 Subject: [PATCH 15/17] add todo --- .../org/apache/spark/sql/catalyst/expressions/Expression.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6ab633f06764..8c8a3fde9cb8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -90,6 +90,8 @@ abstract class Expression extends TreeNode[Expression] { /** * Check the input data types, returns `TypeCheckResult.success` if it's valid, * or return a `TypeCheckResult` with an error message if invalid. + * TODO: we should remove the default implementation and implement it for all + * expressions with proper error message. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success } From b917275e79301595b27364737465417d98f3b953 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 30 May 2015 14:16:36 +0800 Subject: [PATCH 16/17] rebase --- core/src/test/scala/org/apache/spark/SparkFunSuite.scala | 4 ++-- .../catalyst/expressions/ExpressionTypeCheckingSuite.scala | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 8cb344332668..9be9db01c7de 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -30,8 +30,8 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging { * Log the suite name and the test name before and after each test. * * Subclasses should never override this method. If they wish to run - * custom code before and after each test, they should should mix in - * the {{org.scalatest.BeforeAndAfter}} trait instead. + * custom code before and after each test, they should mix in the + * {{org.scalatest.BeforeAndAfter}} trait instead. 
*/ final protected override def withFixture(test: NoArgTest): Outcome = { val testName = test.text diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index 3cf67021fc93..0aca2ea2111a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -24,10 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.StringType -import org.scalatest.FunSuite - - -class ExpressionTypeCheckingSuite extends FunSuite { +class ExpressionTypeCheckingSuite extends SparkFunSuite { val testRelation = LocalRelation( 'intField.int, From b5ff31b0dde66ed24634dc8773dfafb11b95ee50 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 1 Jun 2015 13:56:05 +0800 Subject: [PATCH 17/17] address comments --- .../sql/catalyst/analysis/CheckAnalysis.scala | 10 +- .../catalyst/analysis/HiveTypeCoercion.scala | 91 ++++++++++--------- .../catalyst/analysis/TypeCheckResult.scala | 26 ++++-- .../sql/catalyst/expressions/Expression.scala | 22 +++-- .../sql/catalyst/expressions/arithmetic.scala | 21 ++++- .../sql/catalyst/expressions/predicates.scala | 41 ++++++--- .../spark/sql/catalyst/util/TypeUtils.scala | 12 +-- .../analysis/DecimalPrecisionSuite.scala | 6 +- .../analysis/HiveTypeCoercionSuite.scala | 15 ++- .../ExpressionTypeCheckingSuite.scala | 9 +- .../apache/spark/sql/json/InferSchema.scala | 2 +- .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- 12 files changed, 153 insertions(+), 104 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 5b689f22bedb..c0695ae36942 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -62,10 +62,12 @@ trait CheckAnalysis { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") - case e: Expression if e.checkInputDataTypes.hasError => - e.failAnalysis( - s"cannot resolve '${e.prettyString}' due to data type mismatch: " + - e.checkInputDataTypes.errorMessage) + case e: Expression if e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + e.failAnalysis( + s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + } case c: Cast if !c.resolved => failAnalysis( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 6ce582919e9c..b064600e94fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -41,7 +41,7 @@ object HiveTypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -57,6 +57,17 @@ object HiveTypeCoercion { case _ => None } + + /** + * Find the tightest common type of a set of types by continuously applying + * `findTightestCommonTypeOfTwo` on these types. + */ + private def findTightestCommonType(types: Seq[DataType]) = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case None => None + case Some(d) => findTightestCommonTypeOfTwo(d, c) + }) + } } /** @@ -180,7 +191,7 @@ trait HiveTypeCoercion { case (l, r) if l.dataType != r.dataType => logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") - findTightestCommonType(l.dataType, r.dataType).map { widestType => + findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() val newRight = @@ -217,7 +228,7 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e case b: BinaryExpression if b.left.dataType != b.right.dataType => - findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType => + findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType => val newLeft = if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) val newRight = @@ -323,7 +334,6 @@ trait HiveTypeCoercion { * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) * sum(e1) p1 + 10 s1 * avg(e1) p1 + 4 s1 + 4 - * compare max(p1, p2) max(s1, s2) * * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision. * @@ -442,10 +452,18 @@ trait HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) + // When we compare 2 decimal types with different precisions, cast them to the smallest + // common precision. case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => val resultType = DecimalType(max(p1, p2), max(s1, s2)) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) + case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2) + if e2.dataType == DecimalType.Unlimited => + b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2)) + case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _)) + if e1.dataType == DecimalType.Unlimited => + b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles @@ -560,7 +578,7 @@ trait HiveTypeCoercion { case a @ CreateArray(children) if !a.resolved => val commonType = a.childTypes.reduce( - (a, b) => findTightestCommonType(a, b).getOrElse(StringType)) + (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType)) CreateArray( children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) @@ -590,12 +608,8 @@ trait HiveTypeCoercion { // compatible with every child column. 
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => val types = es.map(_.dataType) - val rt = types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case None => None - case Some(d) => findTightestCommonType(d, c) - }) - rt match { - case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt))) + findTightestCommonType(types) match { + case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") } @@ -608,7 +622,7 @@ trait HiveTypeCoercion { */ object Division extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip Divisions who has not been resolved yet, + // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.resolved => e @@ -624,47 +638,36 @@ trait HiveTypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends Rule[LogicalPlan] { - import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw: CaseWhenLike if cw.childrenResolved && cw.checkInputDataTypes().hasError => - logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") - cw.valueTypes.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case None => None - case Some(d) => findTightestCommonType(d, c) - }).map { commonType => - val transformedBranches = cw.branches.sliding(2, 2).map { + case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => + logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") + val maybeCommonType = findTightestCommonType(c.valueTypes) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { case Seq(when, value) if value.dataType != commonType => Seq(when, Cast(value, commonType)) case Seq(elseVal) if elseVal.dataType != commonType => Seq(Cast(elseVal, commonType)) - case s => s + case other => other }.reduce(_ ++ _) - cw match { - case _: CaseWhen => - CaseWhen(transformedBranches) - case CaseKeyWhen(key, _) => - CaseKeyWhen(key, transformedBranches) + c match { + case _: CaseWhen => CaseWhen(castedBranches) + case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches) } - }.getOrElse(cw) - - case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved => - val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) => - findTightestCommonType(v1, v2).getOrElse(sys.error( - s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) - } - val transformedBranches = ckw.branches.sliding(2, 2).map { - case Seq(when, then) if when.dataType != commonType => - Seq(Cast(when, commonType), then) - case s => s - }.reduce(_ ++ _) - val transformedKey = if (ckw.key.dataType != commonType) { - Cast(ckw.key, commonType) - } else { - ckw.key - } - CaseKeyWhen(transformedKey, transformedBranches) + }.getOrElse(c) + + case c: CaseKeyWhen if c.childrenResolved && !c.resolved => + val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType)) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { + case Seq(when, then) if when.dataType != commonType => + Seq(Cast(when, commonType), then) + case other => other + }.reduce(_ ++ _) + CaseKeyWhen(Cast(c.key, commonType), castedBranches) + }.getOrElse(c) } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala
index 653015154fc1..79c3528a522d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala
@@ -19,15 +19,27 @@ package org.apache.spark.sql.catalyst.analysis

 /**
  * Represents the result of `Expression.checkInputDataTypes`.
- * We will throw `AnalysisException` in `CheckAnalysis` if error message is not null.
- * Use [[TypeCheckResult.success]] and [[TypeCheckResult.fail]] to instantiate this.
- *
+ * We will throw `AnalysisException` in `CheckAnalysis` if `isFailure` is true.
  */
-class TypeCheckResult private (val errorMessage: String) extends AnyVal {
-  def hasError: Boolean = errorMessage != null
+trait TypeCheckResult {
+  def isFailure: Boolean = !isSuccess
+  def isSuccess: Boolean
 }

 object TypeCheckResult {
-  val success: TypeCheckResult = new TypeCheckResult(null)
-  def fail(msg: String): TypeCheckResult = new TypeCheckResult(msg)
+
+  /**
+   * Represents the successful result of `Expression.checkInputDataTypes`.
+   */
+  object TypeCheckSuccess extends TypeCheckResult {
+    def isSuccess: Boolean = true
+  }
+
+  /**
+   * Represents the failing result of `Expression.checkInputDataTypes`,
+   * with an error message to show the reason of failure.
+   */
+  case class TypeCheckFailure(message: String) extends TypeCheckResult {
+    def isSuccess: Boolean = false
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 8c8a3fde9cb8..4ed0697b6f82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -45,11 +45,12 @@ abstract class Expression extends TreeNode[Expression] {

   /**
    * Returns `true` if this expression and all its children have been resolved to a specific schema
-   * and `false` if it still contains any unresolved placeholders. Implementations of expressions
-   * should override this if the resolution of this type of expression involves more than just
-   * the resolution of its children.
+   * and input data type checks passed, and `false` if it still contains any unresolved
+   * placeholders or has a data type mismatch.
+   * Implementations of expressions should override this if the resolution of this type of
+   * expression involves more than just the resolution of its children and type checking.
    */
-  lazy val resolved: Boolean = childrenResolved && !checkInputDataTypes().hasError
+  lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

   /**
    * Returns the [[DataType]] of the result of evaluating this expression. It is
@@ -88,18 +89,19 @@ abstract class Expression extends TreeNode[Expression] {
   }

   /**
-   * Check the input data types, returns `TypeCheckResult.success` if it's valid,
-   * or return a `TypeCheckResult` with an error message if invalid.
+   * Checks the input data types, returns `TypeCheckResult.TypeCheckSuccess` if it's valid,
+   * or returns a `TypeCheckResult` with an error message if invalid.
+   * Note: it's not valid to call this method until `childrenResolved == true`.
    * TODO: we should remove the default implementation and implement it for all
    * expressions with proper error message.
    */
-  def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
+  def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
 }
 
 abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
   self: Product =>
 
-  def symbol: String = sys.error(s"BinaryExpressions must either override toString or symbol")
+  def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol")
 
   override def foldable: Boolean = left.foldable && right.foldable
@@ -137,9 +139,9 @@ trait ExpectsInputTypes {
 
   def expectedChildTypes: Seq[DataType]
 
-  override def checkInputDataTypes: TypeCheckResult = {
+  override def checkInputDataTypes(): TypeCheckResult = {
     // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
     // so type mismatch error won't be reported here, but for underlying `Cast`s.
-    TypeCheckResult.success
+    TypeCheckResult.TypeCheckSuccess
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 0c2b7b4351da..2ac53f8f6613 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -38,7 +38,7 @@ abstract class UnaryArithmetic extends UnaryExpression {
   }
 
   protected def evalInternal(evalE: Any): Any =
-    sys.error(s"UnaryArithmetics must either override eval or evalInternal")
+    sys.error(s"UnaryArithmetics must override either eval or evalInternal")
 }
 
 case class UnaryMinus(child: Expression) extends UnaryArithmetic {
@@ -90,7 +90,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (left.dataType != right.dataType) {
-      TypeCheckResult.fail(
+      TypeCheckResult.TypeCheckFailure(
         s"differing types in ${this.getClass.getSimpleName} " +
         s"(${left.dataType} and ${right.dataType}).")
     } else {
@@ -115,12 +115,15 @@ abstract class BinaryArithmetic extends BinaryExpression {
   }
 
   protected def evalInternal(evalE1: Any, evalE2: Any): Any =
-    sys.error(s"BinaryArithmetics must either override eval or evalInternal")
+    sys.error(s"BinaryArithmetics must override either eval or evalInternal")
 }
 
 case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "+"
 
+  override lazy val resolved =
+    childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
+
   protected def checkTypesInternal(t: DataType) =
     TypeUtils.checkForNumericExpr(t, "operator " + symbol)
@@ -132,6 +135,9 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
 case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "-"
 
+  override lazy val resolved =
+    childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
+
   protected def checkTypesInternal(t: DataType) =
     TypeUtils.checkForNumericExpr(t, "operator " + symbol)
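Each arithmetic operator here repeats the same resolved override: the extra !DecimalType.isFixed(dataType) conjunct keeps a fixed-precision decimal result unresolved until the DecimalPrecision rule has recomputed its width. A small illustration of what the guard distinguishes, using the DecimalType API the tests in this patch also use (the comments on rule behavior are our reading, not code from the patch):

import org.apache.spark.sql.types._

// Fixed-precision decimals are the one numeric case whose result width must be
// recomputed (by the DecimalPrecision rule) before the operator may resolve.
assert(DecimalType.isFixed(DecimalType(10, 2)))     // e.g. Add over these stays unresolved
assert(!DecimalType.isFixed(DecimalType.Unlimited)) // unlimited decimals resolve normally
assert(!DecimalType.isFixed(IntegerType))           // non-decimal numerics are unaffected

@@ -143,6 +149,9 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
 case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "*"
 
+  override lazy val resolved =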
+    childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
+
   protected def checkTypesInternal(t: DataType) =
     TypeUtils.checkForNumericExpr(t, "operator " + symbol)
@@ -155,6 +164,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
   override def symbol: String = "/"
   override def nullable: Boolean = true
 
+  override lazy val resolved =
+    childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
+
   protected def checkTypesInternal(t: DataType) =
     TypeUtils.checkForNumericExpr(t, "operator " + symbol)
@@ -182,6 +194,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
   override def symbol: String = "%"
   override def nullable: Boolean = true
 
+  override lazy val resolved =
+    childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
+
   protected def checkTypesInternal(t: DataType) =
     TypeUtils.checkForNumericExpr(t, "operator " + symbol)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 54a5ae9c3bb4..807021d50e8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -174,7 +174,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (left.dataType != right.dataType) {
-      TypeCheckResult.fail(
+      TypeCheckResult.TypeCheckFailure(
         s"differing types in ${this.getClass.getSimpleName} " +
         s"(${left.dataType} and ${right.dataType}).")
     } else {
@@ -199,7 +199,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
   }
 
   protected def evalInternal(evalE1: Any, evalE2: Any): Any =
-    sys.error(s"BinaryComparisons must either override eval or evalInternal")
+    sys.error(s"BinaryComparisons must override either eval or evalInternal")
 }
 
 object BinaryComparison {
@@ -210,7 +210,7 @@ object BinaryComparison {
 case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "="
 
-  override protected def checkTypesInternal(t: DataType) = TypeCheckResult.success
+  override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess
 
   protected override def evalInternal(l: Any, r: Any) = {
     if (left.dataType != BinaryType) l == r
@@ -220,9 +220,10 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
 case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "<=>"
 
+  override def nullable: Boolean = false
 
-  override protected def checkTypesInternal(t: DataType) = TypeCheckResult.success
+  override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess
 
   override def eval(input: Row): Any = {
     val l = left.eval(input)
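EqualNullSafe can declare nullable = false because <=> is total: it yields a boolean even when one or both sides are NULL, where plain = would yield NULL. A minimal sketch of the truth table its eval implements, in plain Scala rather than the expression framework (nullSafeEq is an illustrative name, not patch code):

// Illustrative only: null-safe equality never returns null.
def nullSafeEq(l: Any, r: Any): Boolean =
  if (l == null && r == null) true
  else if (l == null || r == null) false
  else l == r

assert(nullSafeEq(null, null)) // plain = would yield NULL here
assert(!nullSafeEq(1, null))   // and here
assert(nullSafeEq(1, 1))

@@ -289,13 +290,13 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (predicate.dataType != BooleanType) {
-      TypeCheckResult.fail(
+      TypeCheckResult.TypeCheckFailure(
         s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
     } else if (trueValue.dataType != falseValue.dataType) {
-      TypeCheckResult.fail(
+      TypeCheckResult.TypeCheckFailure(
         s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).")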
     } else {
-      TypeCheckResult.success
+      TypeCheckResult.TypeCheckSuccess
     }
   }
@@ -326,16 +327,16 @@ trait CaseWhenLike extends Expression {
     branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
   val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
 
-  // both then and else val should be considered.
+  // both then and else expressions should be considered.
   def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
   def valueTypesEqual: Boolean = valueTypes.distinct.size == 1
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (valueTypes.distinct.size > 1) {
-      TypeCheckResult.fail(
-        "THEN and ELSE expressions should all be same type or coercible to a common type")
-    } else {
+    if (valueTypesEqual) {
       checkTypesInternal()
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        "THEN and ELSE expressions should all be same type or coercible to a common type")
     }
   }
@@ -365,9 +366,12 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
 
   override protected def checkTypesInternal(): TypeCheckResult = {
     if (whenList.forall(_.dataType == BooleanType)) {
-      TypeCheckResult.success
+      TypeCheckResult.TypeCheckSuccess
     } else {
-      TypeCheckResult.fail(s"WHEN expressions in CaseWhen should all be boolean type")
+      val index = whenList.indexWhere(_.dataType != BooleanType)
+      TypeCheckResult.TypeCheckFailure(
+        s"WHEN expressions in CaseWhen should all be boolean type, " +
+        s"but the ${index + 1}th when expression's type is ${whenList(index).dataType}")
     }
   }
@@ -412,7 +416,14 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
 
   override def children: Seq[Expression] = key +: branches
 
-  override protected def checkTypesInternal(): TypeCheckResult = TypeCheckResult.success
+  override protected def checkTypesInternal(): TypeCheckResult = {
+    if ((key +: whenList).map(_.dataType).distinct.size > 1) {
+      TypeCheckResult.TypeCheckFailure(
+        "key and WHEN expressions should all be same type or coercible to a common type")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
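The branches sequence shared by both subclasses interleaves conditions and values, with an optional trailing ELSE; the sliding(2, 2) and grouped(2) traversals above depend on that layout. A self-contained sketch of the decomposition performed by CaseWhenLike's whenList, thenList and elseValue (decompose is a hypothetical helper, not patch code):

// CASE WHEN w1 THEN t1 WHEN w2 THEN t2 ELSE e END
//   ==> branches == Seq(w1, t1, w2, t2, e)
def decompose[T](branches: Seq[T]): (Seq[T], Seq[T], Option[T]) = {
  val whenList = branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
  val thenList = branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
  val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
  (whenList, thenList, elseValue)
}

assert(decompose(Seq("w1", "t1", "w2", "t2", "e")) ==
  (Seq("w1", "w2"), Seq("t1", "t2"), Some("e")))

   /** Written in imperative fashion for performance considerations. */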
   override def eval(input: Row): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 26df4fbfcf31..0bb12d2039ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -26,25 +26,25 @@ import org.apache.spark.sql.types._
 object TypeUtils {
   def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = {
     if (t.isInstanceOf[NumericType] || t == NullType) {
-      TypeCheckResult.success
+      TypeCheckResult.TypeCheckSuccess
     } else {
-      TypeCheckResult.fail(s"$caller accepts numeric types, not $t")
+      TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric types, not $t")
     }
   }
 
   def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = {
     if (t.isInstanceOf[IntegralType] || t == NullType) {
-      TypeCheckResult.success
+      TypeCheckResult.TypeCheckSuccess
     } else {
-      TypeCheckResult.fail(s"$caller accepts integral types, not $t")
+      TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t")
     }
   }
 
   def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = {
     if (t.isInstanceOf[AtomicType] || t == NullType) {
-      TypeCheckResult.success
+      TypeCheckResult.TypeCheckSuccess
     } else {
-      TypeCheckResult.fail(s"$caller accepts non-complex types, not $t")
+      TypeCheckResult.TypeCheckFailure(s"$caller accepts non-complex types, not $t")
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 1b8d18ded225..7bac97b7894f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -92,8 +92,10 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
   }
 
   test("Comparison operations") {
-    checkComparison(LessThan(i, d1), DecimalType.Unlimited)
-    checkComparison(LessThanOrEqual(d1, d2), DecimalType.Unlimited)
+    checkComparison(EqualTo(i, d1), DecimalType(10, 1))
+    checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2))
+    checkComparison(LessThan(i, d1), DecimalType(10, 1))
+    checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2))
     checkComparison(GreaterThan(d2, u), DecimalType.Unlimited)
     checkComparison(GreaterThanOrEqual(d1, f), DoubleType)
     checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
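The TypeUtils checkers above give operators one-line type guards, and each deliberately accepts NullType so untyped NULL literals survive analysis until coercion runs. A few illustrative calls against the API as changed in this patch:

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

assert(TypeUtils.checkForNumericExpr(DoubleType, "operator +") == TypeCheckSuccess)
assert(TypeUtils.checkForNumericExpr(NullType, "operator +") == TypeCheckSuccess)
// Anything else yields the message built above, e.g.
// TypeCheckFailure("operator + accepts numeric types, not StringType").
assert(TypeUtils.checkForNumericExpr(StringType, "operator +").isFailure)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index a0798428db09..0df446636ea8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -28,11 +28,11 @@ class HiveTypeCoercionSuite extends PlanTest {
 
   test("tightest common bound for types") {
     def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
-      var found = HiveTypeCoercion.findTightestCommonType(t1, t2)
+      var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2)
       assert(found == tightestCommon,
         s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
       // Test both directions to make sure the widening is symmetric.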
-      found = HiveTypeCoercion.findTightestCommonType(t2, t1)
+      found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t2, t1)
       assert(found == tightestCommon,
         s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
     }
@@ -140,13 +140,10 @@ class HiveTypeCoercionSuite extends PlanTest {
       CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
       CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
     )
-    // Will remove exception expectation in PR#6405
-    intercept[RuntimeException] {
-      ruleTest(cwc,
-        CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
-        CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
-      )
-    }
+    ruleTest(cwc,
+      CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
+      CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
+    )
   }
 
   test("type coercion simplification for equal to") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
index 0aca2ea2111a..dcb3635c5cca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
@@ -103,8 +103,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertSuccess(GreaterThan('intField, 'stringField))
     assertSuccess(GreaterThanOrEqual('intField, 'stringField))
 
-    assertErrorForDifferingTypes(EqualTo('intField, 'booleanField))
-    assertErrorForDifferingTypes(EqualNullSafe('intField, 'booleanField))
+    // We will transform EqualTo with numeric and boolean types to CaseKeyWhen
+    assertSuccess(EqualTo('intField, 'booleanField))
+    assertSuccess(EqualNullSafe('intField, 'booleanField))
+
+    assertError(EqualTo('intField, 'complexField), "differing types")
+    assertError(EqualNullSafe('intField, 'complexField), "differing types")
+
     assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
     assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
     assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
index 06aa19ef09bd..565d10247f10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -147,7 +147,7 @@ private[sql] object InferSchema {
    * Returns the most general data type for two given data types.
    */
   private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
-    HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+    HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse {
       // t1 or t2 is a StructType, ArrayType, or an unexpected type.
       (t1, t2) match {
         case (other: DataType, NullType) => other
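Both JSON call sites keep their own structural merging as a fallback for when the renamed pairwise helper returns None. A condensed sketch of that fallback shape (simplified, and compatibleTypeSketch is our name: the real branches also merge StructType fields and ArrayType elements instead of jumping straight to StringType):

import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.types._

// Simplified sketch of the compatibleType pattern used in both JSON files.
def compatibleTypeSketch(t1: DataType, t2: DataType): DataType =
  HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse {
    (t1, t2) match {
      case (other, NullType) => other
      case (NullType, other) => other
      case _ => StringType // the real code handles structs and arrays first
    }
  }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 95eb1174b1dd..7e1e21f5fbb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -155,7 +155,7 @@ private[sql] object JsonRDD extends Logging {
    * Returns the most general data type for two given data types.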
    */
   private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
-    HiveTypeCoercion.findTightestCommonType(t1, t2) match {
+    HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) match {
       case Some(commonType) => commonType
       case None =>
         // t1 or t2 is a StructType, ArrayType, or an unexpected type.