diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala
index e7f0e571804d3..d3d17dc6b65b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala
@@ -19,15 +19,15 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.expressions.postgreSQL.PostgreCastToBoolean
+import org.apache.spark.sql.catalyst.expressions.postgreSQL.{PostgreCastToBoolean, PostgreCastToInteger}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, StringType}
+import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}
 
 object PostgreSQLDialect {
   val postgreSQLDialectRules: List[Rule[LogicalPlan]] =
-    CastToBoolean ::
+    CastToBoolean :: CastToInt ::
       Nil
 
   object CastToBoolean extends Rule[LogicalPlan] with Logging {
@@ -46,4 +46,21 @@ object PostgreSQLDialect {
       }
     }
   }
+
+  object CastToInt extends Rule[LogicalPlan] with Logging {
+    override def apply(plan: LogicalPlan): LogicalPlan = {
+      // The SQL configuration `spark.sql.dialect` can be changed in runtime.
+      // To make sure the configuration is effective, we have to check it during rule execution.
+      val conf = SQLConf.get
+      if (conf.usePostgreSQLDialect) {
+        plan.transformExpressions {
+          case Cast(child, dataType, timeZoneId)
+            if dataType == IntegerType =>
+            PostgreCastToInteger(child, timeZoneId)
+        }
+      } else {
+        plan
+      }
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f3b58fa3137b1..bcc500fa10ee3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -488,7 +488,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   }
 
   // IntConverter
-  private[this] def castToInt(from: DataType): Any => Any = from match {
+  private[catalyst] def castToInt(from: DataType): Any => Any = from match {
     case StringType =>
       val result = new IntWrapper()
       buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
@@ -1394,7 +1394,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
       (c, evPrim, evNull) => code"$evPrim = (short) $c;"
   }
 
-  private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
+  private[catalyst] def castToIntCode(
+      from: DataType,
+      ctx: CodegenContext): CastFunction = from match {
     case StringType =>
       val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
       (c, evPrim, evNull) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToInteger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToInteger.scala
new file mode 100644
index 0000000000000..c48a4e307db15
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToInteger.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.postgreSQL
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.{CastBase, Expression, TimeZoneAwareExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.UTF8String.IntWrapper
+
+/**
+ * Cast of `child` to [[IntegerType]] that is substituted for `Cast` under the PostgreSQL dialect.
+ *
+ * Byte, timestamp and date inputs are rejected by the type check. A string that cannot be
+ * parsed as an integer raises an `IllegalArgumentException` instead of evaluating to null.
+ */
+case class PostgreCastToInteger(child: Expression, timeZoneId: Option[String])
+  extends CastBase {
+  override def dataType: DataType = IntegerType
+
+  override protected def ansiEnabled: Boolean =
+    throw new AnalysisException("PostgreSQL dialect doesn't support ansi mode")
+
+  override def nullable: Boolean = true
+
+  /** Returns a copy of this expression with the specified timeZoneId. */
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
+
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case ByteType | TimestampType | DateType =>
+      TypeCheckResult.TypeCheckFailure(s"Cannot cast type ${child.dataType} to int")
+    // Every other type goes through the inherited check rather than being accepted blindly,
+    // because castToInt below only has cases for string, boolean and numeric inputs.
+    case _ => super.checkInputDataTypes()
+  }
+
+  override def castToInt(from: DataType): Any => Any = from match {
+    case StringType =>
+      val result = new IntWrapper()
+      buildCast[UTF8String](_, s => if (s.toInt(result)) {
+        result.value
+      } else {
+        // Raised while a value is being evaluated, so it is not reported as an analysis error.
+        throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s")
+      })
+    case BooleanType =>
+      buildCast[Boolean](_, b => if (b) 1 else 0)
+    case x: NumericType =>
+      b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
+  }
+
+  override def castToIntCode(
+      from: DataType,
+      ctx: CodegenContext): CastFunction = from match {
+    case StringType =>
+      val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
+      // The emitted text is Java, so the message is built with `+` and a plain string literal.
+      (c, evPrim, _) =>
+        code"""
+          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
+          if ($c.toInt($wrapper)) {
+            $evPrim = $wrapper.value;
+          } else {
+            throw new IllegalArgumentException("invalid input syntax for type numeric: " + $c);
+          }
+          $wrapper = null;
+        """
+    case BooleanType =>
+      (c, evPrim, _) => code"$evPrim = $c ? 1 : 0;"
+    case _: NumericType =>
+      (c, evPrim, _) => code"$evPrim = (int) $c;"
+  }
+
+  override def toString: String = s"PostgreCastToInt($child as ${dataType.simpleString})"
+
+  override def sql: String = s"CAST(${child.sql} AS ${dataType.sql})"
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala
index 6c5218b379f31..09da20fb1612a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala
@@ -70,4 +70,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(PostgreCastToBoolean(Literal(1.toDouble), None).checkInputDataTypes().isFailure)
     assert(PostgreCastToBoolean(Literal(1.toFloat), None).checkInputDataTypes().isFailure)
   }
+
+  test("Unsupported data types to cast to integer") {
+    assert(PostgreCastToInteger(Literal(new Timestamp(1)), None).checkInputDataTypes().isFailure)
+    assert(PostgreCastToInteger(Literal(new Date(1)), None).checkInputDataTypes().isFailure)
+    assert(PostgreCastToInteger(Literal(1.toByte), None).checkInputDataTypes().isFailure)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala
index 7056f483609a9..c3a67d762eb0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala
@@ -39,4 +39,17 @@ class PostgreSQLDialectQuerySuite extends QueryTest with SharedSparkSession {
       intercept[IllegalArgumentException](sql(s"select cast('$input' as boolean)").collect())
     }
   }
+
+  test("Cast to integer") {
+    assert(intercept[AnalysisException](
+      sql(s"SELECT cast(cast('1' as byte) as int)")
+    ).getMessage contains "Cannot cast")
+    assert(intercept[AnalysisException](
+      sql(s"SELECT cast(cast('1-1-1' as date) as int)")
+    ).getMessage contains "Cannot cast")
+    assert(intercept[AnalysisException](
+      sql(s"SELECT cast(cast('1-1-1' as timestamp) as int)")
+    ).getMessage contains "Cannot cast")
+    intercept[IllegalArgumentException](sql("SELECT cast('abc' as int)").collect())
+  }
 }