diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala index 934e53703e241..e7f0e571804d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.spark.sql.catalyst.expressions.postgreSQL.PostgreCastStringToBoolean +import org.apache.spark.sql.catalyst.expressions.postgreSQL.PostgreCastToBoolean import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -27,19 +27,19 @@ import org.apache.spark.sql.types.{BooleanType, StringType} object PostgreSQLDialect { val postgreSQLDialectRules: List[Rule[LogicalPlan]] = - CastStringToBoolean :: + CastToBoolean :: Nil - object CastStringToBoolean extends Rule[LogicalPlan] with Logging { + object CastToBoolean extends Rule[LogicalPlan] with Logging { override def apply(plan: LogicalPlan): LogicalPlan = { // The SQL configuration `spark.sql.dialect` can be changed in runtime. // To make sure the configuration is effective, we have to check it during rule execution. val conf = SQLConf.get if (conf.usePostgreSQLDialect) { plan.transformExpressions { - case Cast(child, dataType, _) - if dataType == BooleanType && child.dataType == StringType => - PostgreCastStringToBoolean(child) + case Cast(child, dataType, timeZoneId) + if child.dataType != BooleanType && dataType == BooleanType => + PostgreCastToBoolean(child, timeZoneId) } } else { plan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 862b2bb515a19..322695036592c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -273,7 +273,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) // [[func]] assumes the input is no longer null because eval already does the null check. - @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) + @inline protected def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) private lazy val dateFormatter = DateFormatter(zoneId) private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId) @@ -376,7 +376,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // UDFToBoolean - private[this] def castToBoolean(from: DataType): Any => Any = from match { + protected[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => { if (StringUtils.isTrueString(s)) { @@ -781,7 +781,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) @@ -791,7 +791,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` // in parameter list, because the returned code will be put in null safe evaluation region. - private[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block + protected type CastFunction = (ExprValue, ExprValue, ExprValue) => Block private[this] def nullSafeCastFunction( from: DataType, @@ -1233,7 +1233,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def timestampToDoubleCode(ts: ExprValue): Block = code"$ts / (double)$MICROS_PER_SECOND" - private[this] def castToBooleanCode(from: DataType): CastFunction = from match { + protected[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala deleted file mode 100644 index 0e87707d01e47..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.expressions.postgreSQL - -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, UnaryExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, JavaCode} -import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.postgreSQL.StringUtils -import org.apache.spark.sql.types.{BooleanType, DataType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -case class PostgreCastStringToBoolean(child: Expression) - extends UnaryExpression with NullIntolerant { - - override def checkInputDataTypes(): TypeCheckResult = { - if (child.dataType == StringType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"The expression ${getClass.getSimpleName} only accepts string input data type") - } - } - - override def nullSafeEval(input: Any): Any = { - val s = input.asInstanceOf[UTF8String].trim().toLowerCase() - if (StringUtils.isTrueString(s)) { - true - } else if (StringUtils.isFalseString(s)) { - false - } else { - null - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" - val eval = child.genCode(ctx) - val javaType = JavaCode.javaType(dataType) - val preprocessedString = ctx.freshName("preprocessedString") - val castCode = - code""" - boolean ${ev.isNull} = ${eval.isNull}; - $javaType ${ev.value} = false; - if (!${eval.isNull}) { - UTF8String $preprocessedString = ${eval.value}.trim().toLowerCase(); - if ($stringUtils.isTrueString($preprocessedString)) { - ${ev.value} = true; - } else if ($stringUtils.isFalseString($preprocessedString)) { - ${ev.value} = false; - } else { - ${ev.isNull} = true; - } - } - """ - ev.copy(code = eval.code + castCode) - } - - override def dataType: DataType = BooleanType - - override def nullable: Boolean = true - - override def toString: String = s"PostgreCastStringToBoolean($child as ${dataType.simpleString})" - - override def sql: String = s"CAST(${child.sql} AS ${dataType.sql})" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToBoolean.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToBoolean.scala new file mode 100644 index 0000000000000..20559ba3cd79e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToBoolean.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.postgreSQL + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{CastBase, Expression, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util.postgreSQL.StringUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +case class PostgreCastToBoolean(child: Expression, timeZoneId: Option[String]) + extends CastBase { + + override protected def ansiEnabled = + throw new UnsupportedOperationException("PostgreSQL dialect doesn't support ansi mode") + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case StringType | IntegerType | NullType => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"cannot cast type ${child.dataType} to boolean") + } + + override def castToBoolean(from: DataType): Any => Any = from match { + case StringType => + buildCast[UTF8String](_, str => { + val s = str.trim().toLowerCase() + if (StringUtils.isTrueString(s)) { + true + } else if (StringUtils.isFalseString(s)) { + false + } else { + throw new IllegalArgumentException(s"invalid input syntax for type boolean: $s") + } + }) + case IntegerType => + super.castToBoolean(from) + } + + override def castToBooleanCode(from: DataType): CastFunction = from match { + case StringType => + val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" + (c, evPrim, evNull) => + code""" + if ($stringUtils.isTrueString($c.trim().toLowerCase())) { + $evPrim = true; + } else if ($stringUtils.isFalseString($c.trim().toLowerCase())) { + $evPrim = false; + } else { + throw new IllegalArgumentException("invalid input syntax for type boolean: $c"); + } + """ + + case IntegerType => + super.castToBooleanCode(from) + } + + override def dataType: DataType = BooleanType + + override def nullable: Boolean = child.nullable + + override def toString: String = s"PostgreCastToBoolean($child as ${dataType.simpleString})" + + override def sql: String = s"CAST(${child.sql} AS ${dataType.sql})" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala index 175904da21969..6c5218b379f31 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala @@ -16,44 +16,58 @@ */ package org.apache.spark.sql.catalyst.expressions.postgreSQL +import java.sql.{Date, Timestamp} + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal} class CastSuite extends SparkFunSuite with ExpressionEvalHelper { - private def checkPostgreCastStringToBoolean(v: Any, expected: Any): Unit = { - checkEvaluation(PostgreCastStringToBoolean(Literal(v)), expected) + private def checkPostgreCastToBoolean(v: Any, expected: Any): Unit = { + checkEvaluation(PostgreCastToBoolean(Literal(v), None), expected) } test("cast string to boolean") { - checkPostgreCastStringToBoolean("true", true) - checkPostgreCastStringToBoolean("tru", true) - checkPostgreCastStringToBoolean("tr", true) - checkPostgreCastStringToBoolean("t", true) - checkPostgreCastStringToBoolean("tRUe", true) - checkPostgreCastStringToBoolean(" tRue ", true) - checkPostgreCastStringToBoolean(" tRu ", true) - checkPostgreCastStringToBoolean("yes", true) - checkPostgreCastStringToBoolean("ye", true) - checkPostgreCastStringToBoolean("y", true) - checkPostgreCastStringToBoolean("1", true) - checkPostgreCastStringToBoolean("on", true) + checkPostgreCastToBoolean("true", true) + checkPostgreCastToBoolean("tru", true) + checkPostgreCastToBoolean("tr", true) + checkPostgreCastToBoolean("t", true) + checkPostgreCastToBoolean("tRUe", true) + checkPostgreCastToBoolean(" tRue ", true) + checkPostgreCastToBoolean(" tRu ", true) + checkPostgreCastToBoolean("yes", true) + checkPostgreCastToBoolean("ye", true) + checkPostgreCastToBoolean("y", true) + checkPostgreCastToBoolean("1", true) + checkPostgreCastToBoolean("on", true) + + checkPostgreCastToBoolean("false", false) + checkPostgreCastToBoolean("fals", false) + checkPostgreCastToBoolean("fal", false) + checkPostgreCastToBoolean("fa", false) + checkPostgreCastToBoolean("f", false) + checkPostgreCastToBoolean(" fAlse ", false) + checkPostgreCastToBoolean(" fAls ", false) + checkPostgreCastToBoolean(" FAlsE ", false) + checkPostgreCastToBoolean("no", false) + checkPostgreCastToBoolean("n", false) + checkPostgreCastToBoolean("0", false) + checkPostgreCastToBoolean("off", false) + checkPostgreCastToBoolean("of", false) - checkPostgreCastStringToBoolean("false", false) - checkPostgreCastStringToBoolean("fals", false) - checkPostgreCastStringToBoolean("fal", false) - checkPostgreCastStringToBoolean("fa", false) - checkPostgreCastStringToBoolean("f", false) - checkPostgreCastStringToBoolean(" fAlse ", false) - checkPostgreCastStringToBoolean(" fAls ", false) - checkPostgreCastStringToBoolean(" FAlsE ", false) - checkPostgreCastStringToBoolean("no", false) - checkPostgreCastStringToBoolean("n", false) - checkPostgreCastStringToBoolean("0", false) - checkPostgreCastStringToBoolean("off", false) - checkPostgreCastStringToBoolean("of", false) + intercept[IllegalArgumentException](PostgreCastToBoolean(Literal("o"), None).eval()) + intercept[IllegalArgumentException](PostgreCastToBoolean(Literal("abc"), None).eval()) + intercept[IllegalArgumentException](PostgreCastToBoolean(Literal(""), None).eval()) + } - checkPostgreCastStringToBoolean("o", null) - checkPostgreCastStringToBoolean("abc", null) - checkPostgreCastStringToBoolean("", null) + test("unsupported data types to cast to boolean") { + assert(PostgreCastToBoolean(Literal(new Timestamp(1)), None).checkInputDataTypes().isFailure) + assert(PostgreCastToBoolean(Literal(new Date(1)), None).checkInputDataTypes().isFailure) + assert(PostgreCastToBoolean(Literal(1.toLong), None).checkInputDataTypes().isFailure) + assert(PostgreCastToBoolean(Literal(1.toShort), None).checkInputDataTypes().isFailure) + assert(PostgreCastToBoolean(Literal(1.toByte), None).checkInputDataTypes().isFailure) + assert(PostgreCastToBoolean(Literal(BigDecimal(1.0)), None).checkInputDataTypes().isFailure) + assert(PostgreCastToBoolean(Literal(1.toDouble), None).checkInputDataTypes().isFailure) + assert(PostgreCastToBoolean(Literal(1.toFloat), None).checkInputDataTypes().isFailure) } } diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/boolean.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/boolean.sql.out index 203806d43368a..e5f3425efc458 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/boolean.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/boolean.sql.out @@ -53,9 +53,10 @@ true -- !query 6 SELECT boolean('test') AS error -- !query 6 schema -struct +struct<> -- !query 6 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: test -- !query 7 @@ -69,9 +70,10 @@ false -- !query 8 SELECT boolean('foo') AS error -- !query 8 schema -struct +struct<> -- !query 8 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: foo -- !query 9 @@ -93,9 +95,10 @@ true -- !query 11 SELECT boolean('yeah') AS error -- !query 11 schema -struct +struct<> -- !query 11 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: yeah -- !query 12 @@ -117,9 +120,10 @@ false -- !query 14 SELECT boolean('nay') AS error -- !query 14 schema -struct +struct<> -- !query 14 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: nay -- !query 15 @@ -149,25 +153,28 @@ false -- !query 18 SELECT boolean('o') AS error -- !query 18 schema -struct +struct<> -- !query 18 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: o -- !query 19 SELECT boolean('on_') AS error -- !query 19 schema -struct +struct<> -- !query 19 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: on_ -- !query 20 SELECT boolean('off_') AS error -- !query 20 schema -struct +struct<> -- !query 20 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: off_ -- !query 21 @@ -181,9 +188,10 @@ true -- !query 22 SELECT boolean('11') AS error -- !query 22 schema -struct +struct<> -- !query 22 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: 11 -- !query 23 @@ -197,17 +205,19 @@ false -- !query 24 SELECT boolean('000') AS error -- !query 24 schema -struct +struct<> -- !query 24 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: 000 -- !query 25 SELECT boolean('') AS error -- !query 25 schema -struct +struct<> -- !query 25 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: -- !query 26 @@ -310,17 +320,19 @@ true false -- !query 38 SELECT boolean(string(' tru e ')) AS invalid -- !query 38 schema -struct +struct<> -- !query 38 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: tru e -- !query 39 SELECT boolean(string('')) AS invalid -- !query 39 schema -struct +struct<> -- !query 39 output -NULL +java.lang.IllegalArgumentException +invalid input syntax for type boolean: -- !query 40 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala index 1354dcfda45fe..7056f483609a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala @@ -36,7 +36,7 @@ class PostgreSQLDialectQuerySuite extends QueryTest with SharedSparkSession { } Seq("o", "abc", "").foreach { input => - checkAnswer(sql(s"select cast('$input' as boolean)"), Row(null)) + intercept[IllegalArgumentException](sql(s"select cast('$input' as boolean)").collect()) } } }