diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3f0d77ad6322a..39e5ef0d0a2be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -69,6 +69,7 @@ class Analyzer(catalog: Catalog, typeCoercionRules ++ extendedRules : _*), Batch("Check Analysis", Once, CheckResolution :: + CheckCast :: CheckAggregation :: Nil: _*), @@ -76,6 +77,20 @@ class Analyzer(catalog: Catalog, EliminateAnalysisOperators) ) + /** + * Makes sure every Cast in the plan is between compatible data types; throws an AnalysisException if not. + */ + object CheckCast extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.transform { + case q: LogicalPlan => + q transformExpressions { + case cast @ Cast(child, dataType) if !cast.resolve(child.dataType, dataType) => + throw new AnalysisException(s"cannot cast from ${child.dataType} to $dataType") + case p => p + } + } + } + /** + * Makes sure all attributes and logical plans have been resolved. 
 */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b1bc858478ee1..5c0269aac831f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -51,7 +51,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to - private[this] def resolve(from: DataType, to: DataType): Boolean = { + private[catalyst] def resolve(from: DataType, to: DataType): Boolean = { (from, to) match { case (from, to) if from == to => true diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 405b200d05412..8a670faae379b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -27,7 +27,7 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.hive._ @@ -66,6 +66,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + test("SPARK-5649: add a rule to check datatype casts") { + intercept[AnalysisException] { + sql("select cast(key as binary) from src").collect() + } + } + createQueryTest("! operator", """ |SELECT a FROM (