From 0e7f28172b30b249f027de858ee5fe34fa598569 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 23 Sep 2019 16:57:29 +0800 Subject: [PATCH 01/10] dialect --- .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 10 ++- .../spark/sql/catalyst/util/StringUtils.scala | 28 ++++-- .../apache/spark/sql/internal/SQLConf.scala | 24 ++++-- .../catalyst/analysis/TypeCoercionSuite.scala | 8 +- .../sql/catalyst/expressions/CastSuite.scala | 85 ++++++++++++------- .../apache/spark/sql/SQLQueryTestSuite.scala | 3 +- .../ThriftServerQueryTestSuite.scala | 2 +- 8 files changed, 108 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 8ea6e1b0f1808..3d3c93fbfb234 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -677,8 +677,10 @@ object TypeCoercion { case d: Divide if d.dataType == DoubleType => d case d: Divide if d.dataType.isInstanceOf[DecimalType] => d case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) => + val preferIntegralDivision = + conf.getConf(SQLConf.DIALECT) == SQLConf.Dialect.POSTGRESQL.toString (left.dataType, right.dataType) match { - case (_: IntegralType, _: IntegralType) if conf.preferIntegralDivision => + case (_: IntegralType, _: IntegralType) if preferIntegralDivision => IntegralDivide(left, right) case _ => Divide(Cast(left, DoubleType), Cast(right, DoubleType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 118f261de775d..fea65745336f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -391,10 +391,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => + val dialect = SQLConf.get.getConf(SQLConf.DIALECT) buildCast[UTF8String](_, s => { - if (StringUtils.isTrueString(s)) { + if (StringUtils.isTrueString(s, dialect)) { true - } else if (StringUtils.isFalseString(s)) { + } else if (StringUtils.isFalseString(s, dialect)) { false } else { null @@ -1250,11 +1251,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" + val dialect = SQLConf.get.getConf(SQLConf.DIALECT) (c, evPrim, evNull) => code""" - if ($stringUtils.isTrueString($c)) { + if ($stringUtils.isTrueString($c, "$dialect")) { $evPrim = true; - } else if ($stringUtils.isFalseString($c)) { + } else if ($stringUtils.isFalseString($c, "$dialect")) { $evPrim = false; } else { $evNull = true; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index a14ae540f5056..27f78b37a67a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -65,16 +65,34 @@ object StringUtils extends Logging { "(?s)" + out.result() // (?s) enables dotall mode, causing "." to match new lines } - // "true", "yes", "1", "false", "no", "0", and unique prefixes of these strings are accepted. private[this] val trueStrings = - Set("true", "tru", "tr", "t", "yes", "ye", "y", "on", "1").map(UTF8String.fromString) + Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) + // "true", "yes", "1", "false", "no", "0", and unique prefixes of these strings are accepted. + private[this] val trueStringsOfPostgreSQL = + Set("true", "tru", "tr", "t", "yes", "ye", "y", "on", "1").map (UTF8String.fromString) private[this] val falseStrings = + Set("f", "false", "n", "no", "0").map(UTF8String.fromString) + private[this] val falseStringsOfPostgreSQL = Set("false", "fals", "fal", "fa", "f", "no", "n", "off", "of", "0").map(UTF8String.fromString) - // scalastyle:off caselocale - def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase.trim()) - def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase.trim()) + def isTrueString(s: UTF8String, dialect: String): Boolean = { + SQLConf.Dialect.withName(dialect) match { + case SQLConf.Dialect.SPARK => + trueStrings.contains(s.toLowerCase) + case SQLConf.Dialect.POSTGRESQL => + trueStringsOfPostgreSQL.contains(s.toLowerCase.trim()) + } + } + + def isFalseString(s: UTF8String, dialect: String): Boolean = { + SQLConf.Dialect.withName(dialect) match { + case SQLConf.Dialect.SPARK => + falseStrings.contains(s.toLowerCase) + case SQLConf.Dialect.POSTGRESQL => + falseStringsOfPostgreSQL.contains(s.toLowerCase.trim()) + } + } // scalastyle:on caselocale /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0ec661fc16c88..a2185370c4ab6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1589,12 +1589,22 @@ object SQLConf { .booleanConf .createWithDefault(false) - val PREFER_INTEGRAL_DIVISION = buildConf("spark.sql.function.preferIntegralDivision") - .internal() - .doc("When true, will perform integral division with the / operator " + - "if both sides are integral types. This is for PostgreSQL test cases only.") - .booleanConf - .createWithDefault(false) + object Dialect extends Enumeration { + val SPARK, POSTGRESQL = Value + } + + val DIALECT = + buildConf("spark.sql.dialect") + .doc("The specific features of the SQL language to be adopted, which are available when " + + "accessing the given database. Currently, Spark supports two database dialects, `Spark` " + + "and `PostgreSQL`. With `PostgreSQL` dialect, Spark will: " + + "1. perform integral division with the / operator if both sides are integral types; " + + "2. accept \"true\", \"yes\", \"1\", \"false\", \"no\", \"0\", and unique prefixes as " + + "input and trim input for the boolean data type.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(Dialect.values.map(_.toString)) + .createWithDefault(Dialect.SPARK.toString) val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION = buildConf("spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation") @@ -2418,8 +2428,6 @@ class SQLConf extends Serializable with Logging { def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) - def preferIntegralDivision: Boolean = getConf(PREFER_INTEGRAL_DIVISION) - def allowCreatingManagedTableUsingNonemptyLocation: Boolean = getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 50c38145ae21d..5d62754b1cec4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1483,15 +1483,15 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-28395 Division operator support integral division") { val rules = Seq(FunctionArgumentConversion, Division(conf)) - Seq(true, false).foreach { preferIntegralDivision => - withSQLConf(SQLConf.PREFER_INTEGRAL_DIVISION.key -> s"$preferIntegralDivision") { - val result1 = if (preferIntegralDivision) { + Seq(SQLConf.Dialect.SPARK, SQLConf.Dialect.POSTGRESQL).foreach { dialect => + withSQLConf(SQLConf.DIALECT.key -> dialect.toString) { + val result1 = if (dialect == SQLConf.Dialect.POSTGRESQL) { IntegralDivide(1L, 1L) } else { Divide(Cast(1L, DoubleType), Cast(1L, DoubleType)) } ruleTest(rules, Divide(1L, 1L), result1) - val result2 = if (preferIntegralDivision) { + val result2 = if (dialect == SQLConf.Dialect.POSTGRESQL) { IntegralDivide(1, Cast(1, ShortType)) } else { Divide(Cast(1, DoubleType), Cast(Cast(1, ShortType), DoubleType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index ffb14e2838687..4a05e780cd146 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -818,37 +818,60 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { "interval 1 years 3 months -3 days") } - test("cast string to boolean") { - checkCast("true", true) - checkCast("tru", true) - checkCast("tr", true) - checkCast("t", true) - checkCast("tRUe", true) - checkCast(" tRue ", true) - checkCast(" tRu ", true) - checkCast("yes", true) - checkCast("ye", true) - checkCast("y", true) - checkCast("1", true) - checkCast("on", true) - - checkCast("false", false) - checkCast("fals", false) - checkCast("fal", false) - checkCast("fa", false) - checkCast("f", false) - checkCast(" fAlse ", false) - checkCast(" fAls ", false) - checkCast(" FAlsE ", false) - checkCast("no", false) - checkCast("n", false) - checkCast("0", false) - checkCast("off", false) - checkCast("of", false) - - checkEvaluation(cast("o", BooleanType), null) - checkEvaluation(cast("abc", BooleanType), null) - checkEvaluation(cast("", BooleanType), null) + test("cast string to boolean with Spark dialect") { + withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.SPARK.toString) { + checkCast("t", true) + checkCast("true", true) + checkCast("tRUe", true) + checkCast("y", true) + checkCast("yes", true) + checkCast("1", true) + + checkCast("f", false) + checkCast("false", false) + checkCast("FAlsE", false) + checkCast("n", false) + checkCast("no", false) + checkCast("0", false) + + checkEvaluation(cast("abc", BooleanType), null) + checkEvaluation(cast("", BooleanType), null) + } + } + + test("cast string to boolean with PostgreSQL dialect") { + withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.POSTGRESQL.toString) { + checkCast("true", true) + checkCast("tru", true) + checkCast("tr", true) + checkCast("t", true) + checkCast("tRUe", true) + checkCast(" tRue ", true) + checkCast(" tRu ", true) + checkCast("yes", true) + checkCast("ye", true) + checkCast("y", true) + checkCast("1", true) + checkCast("on", true) + + checkCast("false", false) + checkCast("fals", false) + checkCast("fal", false) + checkCast("fa", false) + checkCast("f", false) + checkCast(" fAlse ", false) + checkCast(" fAls ", false) + checkCast(" FAlsE ", false) + checkCast("no", false) + checkCast("n", false) + checkCast("0", false) + checkCast("off", false) + checkCast("of", false) + + checkEvaluation(cast("o", BooleanType), null) + checkEvaluation(cast("abc", BooleanType), null) + checkEvaluation(cast("", BooleanType), null) + } } test("SPARK-16729 type checking for casting to date type") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 1a41dd95a5700..788535e0bd1bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -311,8 +311,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession { // PostgreSQL enabled cartesian product by default. localSparkSession.conf.set(SQLConf.CROSS_JOINS_ENABLED.key, true) localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, true) - localSparkSession.conf.set(SQLConf.PREFER_INTEGRAL_DIVISION.key, true) - localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, true) + localSparkSession.conf.set(SQLConf.DIALECT.key, SQLConf.Dialect.POSTGRESQL.toString) case _ => } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala index 381b8f2324ca6..f14ea958c1b2d 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -111,7 +111,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { // PostgreSQL enabled cartesian product by default. statement.execute(s"SET ${SQLConf.CROSS_JOINS_ENABLED.key} = true") statement.execute(s"SET ${SQLConf.ANSI_ENABLED.key} = true") - statement.execute(s"SET ${SQLConf.PREFER_INTEGRAL_DIVISION.key} = true") + statement.execute(s"SET ${SQLConf.DIALECT.key} = ${SQLConf.Dialect.POSTGRESQL.toString}") case _ => } From 094d58ec330d4874045e364705813c3fd521e99e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 24 Sep 2019 19:45:05 +0800 Subject: [PATCH 02/10] new package and expression --- .../sql/catalyst/analysis/Analyzer.scala | 1 + .../catalyst/analysis/PostgreSQLDialect.scala | 45 +++++++++++ .../sql/catalyst/analysis/TypeCoercion.scala | 3 +- .../spark/sql/catalyst/expressions/Cast.scala | 10 +-- .../expressions/postgreSQL/Cast.scala | 78 +++++++++++++++++++ .../spark/sql/catalyst/util/StringUtils.scala | 25 +++--- .../apache/spark/sql/internal/SQLConf.scala | 2 + .../sql/catalyst/expressions/CastSuite.scala | 71 +++++------------ .../expressions/postgreSQL/CastSuite.scala | 59 ++++++++++++++ .../sql/PostgreSQLDialectQuerySuite.scala | 43 ++++++++++ 10 files changed, 260 insertions(+), 77 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/Cast.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7c81185388d02..65696e84a7587 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -227,6 +227,7 @@ class Analyzer( ResolveRandomSeed :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), + Batch("PostgreSQl dialect", Once, PostgreSQLDialect.postgreSQLDialectRules(conf): _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), Batch("Nondeterministic", Once, PullOutNondeterministic), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala new file mode 100644 index 0000000000000..9fd05bee1c167 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.postgreSQL.PostgreCastStringToBoolean +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{BooleanType, StringType} + +object PostgreSQLDialect { + def postgreSQLDialectRules(conf: SQLConf): List[Rule[LogicalPlan]] = + if (conf.usePostgreSQLDialect) { + postgreCastStringToBoolean(conf) :: + Nil + } else { + Nil + } + + case class postgreCastStringToBoolean(conf: SQLConf) extends Rule[LogicalPlan] with Logging { + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.transformExpressions { + case Cast(child, dataType, _) if dataType == BooleanType && child.dataType == StringType => + PostgreCastStringToBoolean(child) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 3d3c93fbfb234..8ab25f2dfda6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -677,8 +677,7 @@ object TypeCoercion { case d: Divide if d.dataType == DoubleType => d case d: Divide if d.dataType.isInstanceOf[DecimalType] => d case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) => - val preferIntegralDivision = - conf.getConf(SQLConf.DIALECT) == SQLConf.Dialect.POSTGRESQL.toString + val preferIntegralDivision = conf.usePostgreSQLDialect (left.dataType, right.dataType) match { case (_: IntegralType, _: IntegralType) if preferIntegralDivision => IntegralDivide(left, right) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index fea65745336f1..118f261de775d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -391,11 +391,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - val dialect = SQLConf.get.getConf(SQLConf.DIALECT) buildCast[UTF8String](_, s => { - if (StringUtils.isTrueString(s, dialect)) { + if (StringUtils.isTrueString(s)) { true - } else if (StringUtils.isFalseString(s, dialect)) { + } else if (StringUtils.isFalseString(s)) { false } else { null @@ -1251,12 +1250,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" - val dialect = SQLConf.get.getConf(SQLConf.DIALECT) (c, evPrim, evNull) => code""" - if ($stringUtils.isTrueString($c, "$dialect")) { + if ($stringUtils.isTrueString($c)) { $evPrim = true; - } else if ($stringUtils.isFalseString($c, "$dialect")) { + } else if ($stringUtils.isFalseString($c)) { $evPrim = false; } else { $evNull = true; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/Cast.scala new file mode 100644 index 0000000000000..349ed98605622 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/Cast.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.postgreSQL + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, JavaCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.types.{BooleanType, DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +case class PostgreCastStringToBoolean(child: Expression) + extends UnaryExpression { + + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType == StringType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"The expression ${getClass.getSimpleName} only accepts string input data type") + } + } + + override def nullSafeEval(input: Any): Any = { + val s = input.asInstanceOf[UTF8String] + if (StringUtils.postgreIsTrueString(s)) { + true + } else if (StringUtils.postgreIsFalseString(s)) { + false + } else { + null + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" + val eval = child.genCode(ctx) + val javaType = JavaCode.javaType(dataType) + val castCode = + code""" + boolean ${ev.isNull} = ${eval.isNull}; + $javaType ${ev.value} = false; + if (!${eval.isNull}) { + if ($stringUtils.postgreIsTrueString(${eval.value})) { + ${ev.value} = true; + } else if ($stringUtils.postgreIsFalseString(${eval.value})) { + ${ev.value} = false; + } else { + ${ev.isNull} = true; + } + } + """ + ev.copy(code = eval.code + castCode) + } + + override def dataType: DataType = BooleanType + + override def nullable: Boolean = true + + override def toString: String = s"postgreCastStringToBoolean($child as ${dataType.simpleString})" + + override def sql: String = s"CAST(${child.sql} AS ${dataType.sql})" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 27f78b37a67a4..1979f075de7c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -75,24 +75,17 @@ object StringUtils extends Logging { Set("f", "false", "n", "no", "0").map(UTF8String.fromString) private[this] val falseStringsOfPostgreSQL = Set("false", "fals", "fal", "fa", "f", "no", "n", "off", "of", "0").map(UTF8String.fromString) + // scalastyle:off caselocale - def isTrueString(s: UTF8String, dialect: String): Boolean = { - SQLConf.Dialect.withName(dialect) match { - case SQLConf.Dialect.SPARK => - trueStrings.contains(s.toLowerCase) - case SQLConf.Dialect.POSTGRESQL => - trueStringsOfPostgreSQL.contains(s.toLowerCase.trim()) - } - } + def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) - def isFalseString(s: UTF8String, dialect: String): Boolean = { - SQLConf.Dialect.withName(dialect) match { - case SQLConf.Dialect.SPARK => - falseStrings.contains(s.toLowerCase) - case SQLConf.Dialect.POSTGRESQL => - falseStringsOfPostgreSQL.contains(s.toLowerCase.trim()) - } - } + def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) + + def postgreIsTrueString(s: UTF8String): Boolean = + trueStringsOfPostgreSQL.contains(s.toLowerCase.trim()) + + def postgreIsFalseString(s: UTF8String): Boolean = + falseStringsOfPostgreSQL.contains(s.toLowerCase.trim()) // scalastyle:on caselocale /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a2185370c4ab6..57f68421cd2e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2441,6 +2441,8 @@ class SQLConf extends Serializable with Logging { def ansiEnabled: Boolean = getConf(ANSI_ENABLED) + def usePostgreSQLDialect: Boolean = getConf(DIALECT) == Dialect.POSTGRESQL.toString() + def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED) def serializerNestedSchemaPruningEnabled: Boolean = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 4a05e780cd146..38aacb2c97b1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -819,59 +819,24 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast string to boolean with Spark dialect") { - withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.SPARK.toString) { - checkCast("t", true) - checkCast("true", true) - checkCast("tRUe", true) - checkCast("y", true) - checkCast("yes", true) - checkCast("1", true) - - checkCast("f", false) - checkCast("false", false) - checkCast("FAlsE", false) - checkCast("n", false) - checkCast("no", false) - checkCast("0", false) - - checkEvaluation(cast("abc", BooleanType), null) - checkEvaluation(cast("", BooleanType), null) - } - } - - test("cast string to boolean with PostgreSQL dialect") { - withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.POSTGRESQL.toString) { - checkCast("true", true) - checkCast("tru", true) - checkCast("tr", true) - checkCast("t", true) - checkCast("tRUe", true) - checkCast(" tRue ", true) - checkCast(" tRu ", true) - checkCast("yes", true) - checkCast("ye", true) - checkCast("y", true) - checkCast("1", true) - checkCast("on", true) - - checkCast("false", false) - checkCast("fals", false) - checkCast("fal", false) - checkCast("fa", false) - checkCast("f", false) - checkCast(" fAlse ", false) - checkCast(" fAls ", false) - checkCast(" FAlsE ", false) - checkCast("no", false) - checkCast("n", false) - checkCast("0", false) - checkCast("off", false) - checkCast("of", false) - - checkEvaluation(cast("o", BooleanType), null) - checkEvaluation(cast("abc", BooleanType), null) - checkEvaluation(cast("", BooleanType), null) - } + checkCast("t", true) + checkCast("true", true) + checkCast("tRUe", true) + checkCast("y", true) + checkCast("yes", true) + checkCast("1", true) + + checkCast("f", false) + checkCast("false", false) + checkCast("FAlsE", false) + checkCast("n", false) + checkCast("no", false) + checkCast("0", false) + + checkEvaluation(cast("abc", BooleanType), null) + checkEvaluation(cast("tru", BooleanType), null) + checkEvaluation(cast("fla", BooleanType), null) + checkEvaluation(cast("", BooleanType), null) } test("SPARK-16729 type checking for casting to date type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala new file mode 100644 index 0000000000000..a8d5a3ad52fda --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.postgreSQL + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal} + +class CastSuite extends SparkFunSuite with ExpressionEvalHelper { + private def checkPostgreCastStringToBoolean(v: Any, expected: Any): Unit = { + checkEvaluation(PostgreCastStringToBoolean(Literal(v)), expected) + } + + test("cast string to boolean with PostgreSQL dialect") { + checkPostgreCastStringToBoolean("true", true) + checkPostgreCastStringToBoolean("tru", true) + checkPostgreCastStringToBoolean("tr", true) + checkPostgreCastStringToBoolean("t", true) + checkPostgreCastStringToBoolean("tRUe", true) + checkPostgreCastStringToBoolean(" tRue ", true) + checkPostgreCastStringToBoolean(" tRu ", true) + checkPostgreCastStringToBoolean("yes", true) + checkPostgreCastStringToBoolean("ye", true) + checkPostgreCastStringToBoolean("y", true) + checkPostgreCastStringToBoolean("1", true) + checkPostgreCastStringToBoolean("on", true) + + checkPostgreCastStringToBoolean("false", false) + checkPostgreCastStringToBoolean("fals", false) + checkPostgreCastStringToBoolean("fal", false) + checkPostgreCastStringToBoolean("fa", false) + checkPostgreCastStringToBoolean("f", false) + checkPostgreCastStringToBoolean(" fAlse ", false) + checkPostgreCastStringToBoolean(" fAls ", false) + checkPostgreCastStringToBoolean(" FAlsE ", false) + checkPostgreCastStringToBoolean("no", false) + checkPostgreCastStringToBoolean("n", false) + checkPostgreCastStringToBoolean("0", false) + checkPostgreCastStringToBoolean("off", false) + checkPostgreCastStringToBoolean("of", false) + + checkPostgreCastStringToBoolean("o", null) + checkPostgreCastStringToBoolean("abc", null) + checkPostgreCastStringToBoolean("", null) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala new file mode 100644 index 0000000000000..4b800d4743030 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.SparkConf +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class PostgreSQLDialectQuerySuite extends QueryTest with SharedSparkSession { + + override def sparkConf: SparkConf = + super.sparkConf + .set(SQLConf.DIALECT.key, SQLConf.Dialect.POSTGRESQL.toString) + + test("cast string to boolean") { + Seq("true", "tru", "tr", "t", " tRue ", " tRu ", "yes", "ye", + "y", "1", "on").foreach { input => + checkAnswer(sql(s"select cast('$input' as boolean)"), Row(true)) + } + Seq("false", "fals", "fal", "fa", "f", " fAlse ", " fAls ", "no", "n", + "0", "off", "of").foreach { input => + checkAnswer(sql(s"select cast('$input' as boolean)"), Row(false)) + } + + Seq("o", "abc", "").foreach { input => + checkAnswer(sql(s"select cast('$input' as boolean)"), Row(null)) + } + } +} From 9f14680cef74823835f4bbb740799f8531a32386 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 24 Sep 2019 22:24:14 +0800 Subject: [PATCH 03/10] revise --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 65696e84a7587..0aed82e32021d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -227,7 +227,7 @@ class Analyzer( ResolveRandomSeed :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), - Batch("PostgreSQl dialect", Once, PostgreSQLDialect.postgreSQLDialectRules(conf): _*), + Batch("PostgreSQL Dialect", Once, PostgreSQLDialect.postgreSQLDialectRules(conf): _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), Batch("Nondeterministic", Once, PullOutNondeterministic), From 5a69400b32d9bf8675f6b84e74b0efcbc5a096d7 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 25 Sep 2019 11:48:58 +0800 Subject: [PATCH 04/10] address comments --- .../catalyst/analysis/PostgreSQLDialect.scala | 4 +-- ...scala => PostgreCastStringToBoolean.scala} | 15 ++++----- .../spark/sql/catalyst/util/StringUtils.scala | 11 ------- .../util/postgreSQL/StringUtils.scala | 33 +++++++++++++++++++ 4 files changed, 42 insertions(+), 21 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/{Cast.scala => PostgreCastStringToBoolean.scala} (85%) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala index 9fd05bee1c167..9857f8621fc65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala @@ -28,13 +28,13 @@ import org.apache.spark.sql.types.{BooleanType, StringType} object PostgreSQLDialect { def postgreSQLDialectRules(conf: SQLConf): List[Rule[LogicalPlan]] = if (conf.usePostgreSQLDialect) { - postgreCastStringToBoolean(conf) :: + CastStringToBoolean(conf) :: Nil } else { Nil } - case class postgreCastStringToBoolean(conf: SQLConf) extends Rule[LogicalPlan] with Logging { + case class CastStringToBoolean(conf: SQLConf) extends Rule[LogicalPlan] with Logging { override def apply(plan: LogicalPlan): LogicalPlan = { plan.transformExpressions { case Cast(child, dataType, _) if dataType == BooleanType && child.dataType == StringType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala similarity index 85% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/Cast.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala index 349ed98605622..515f74f2df7be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala @@ -20,12 +20,11 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.catalyst.util.postgreSQL.StringUtils import org.apache.spark.sql.types.{BooleanType, DataType, StringType} import org.apache.spark.unsafe.types.UTF8String -case class PostgreCastStringToBoolean(child: Expression) - extends UnaryExpression { +case class PostgreCastStringToBoolean(child: Expression) extends UnaryExpression { override def checkInputDataTypes(): TypeCheckResult = { if (child.dataType == StringType) { @@ -38,9 +37,9 @@ case class PostgreCastStringToBoolean(child: Expression) override def nullSafeEval(input: Any): Any = { val s = input.asInstanceOf[UTF8String] - if (StringUtils.postgreIsTrueString(s)) { + if (StringUtils.isTrueString(s)) { true - } else if (StringUtils.postgreIsFalseString(s)) { + } else if (StringUtils.isFalseString(s)) { false } else { null @@ -56,9 +55,9 @@ case class PostgreCastStringToBoolean(child: Expression) boolean ${ev.isNull} = ${eval.isNull}; $javaType ${ev.value} = false; if (!${eval.isNull}) { - if ($stringUtils.postgreIsTrueString(${eval.value})) { + if ($stringUtils.isTrueString(${eval.value})) { ${ev.value} = true; - } else if ($stringUtils.postgreIsFalseString(${eval.value})) { + } else if ($stringUtils.isFalseString(${eval.value})) { ${ev.value} = false; } else { ${ev.isNull} = true; @@ -72,7 +71,7 @@ case class PostgreCastStringToBoolean(child: Expression) override def nullable: Boolean = true - override def toString: String = s"postgreCastStringToBoolean($child as ${dataType.simpleString})" + override def toString: String = s"PostgreCastStringToBoolean($child as ${dataType.simpleString})" override def sql: String = s"CAST(${child.sql} AS ${dataType.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 1979f075de7c5..3bda9a2a1fc48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -67,25 +67,14 @@ object StringUtils extends Logging { private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) - // "true", "yes", "1", "false", "no", "0", and unique prefixes of these strings are accepted. - private[this] val trueStringsOfPostgreSQL = - Set("true", "tru", "tr", "t", "yes", "ye", "y", "on", "1").map (UTF8String.fromString) private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) - private[this] val falseStringsOfPostgreSQL = - Set("false", "fals", "fal", "fa", "f", "no", "n", "off", "of", "0").map(UTF8String.fromString) // scalastyle:off caselocale def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) - - def postgreIsTrueString(s: UTF8String): Boolean = - trueStringsOfPostgreSQL.contains(s.toLowerCase.trim()) - - def postgreIsFalseString(s: UTF8String): Boolean = - falseStringsOfPostgreSQL.contains(s.toLowerCase.trim()) // scalastyle:on caselocale /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala new file mode 100644 index 0000000000000..80d95a466f9f5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util.postgreSQL + +import org.apache.spark.unsafe.types.UTF8String + +object StringUtils { + // "true", "yes", "1", "false", "no", "0", and unique prefixes of these strings are accepted. + private[this] val trueStrings = + Set("true", "tru", "tr", "t", "yes", "ye", "y", "on", "1").map (UTF8String.fromString) + + private[this] val falseStrings = + Set("false", "fals", "fal", "fa", "f", "no", "n", "off", "of", "0").map(UTF8String.fromString) + + def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.trim().toLowerCase()) + + def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.trim().toLowerCase()) +} From bb657e929b7b05486b78234ed93aad48d5de6524 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 25 Sep 2019 14:10:14 +0800 Subject: [PATCH 05/10] fix test failure --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../catalyst/analysis/PostgreSQLDialect.scala | 24 +++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0aed82e32021d..6e3af91a2fc7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -227,7 +227,7 @@ class Analyzer( ResolveRandomSeed :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), - Batch("PostgreSQL Dialect", Once, PostgreSQLDialect.postgreSQLDialectRules(conf): _*), + Batch("PostgreSQL Dialect", Once, PostgreSQLDialect.postgreSQLDialectRules: _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), Batch("Nondeterministic", Once, PullOutNondeterministic), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala index 9857f8621fc65..c5d0fd5282516 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala @@ -26,19 +26,23 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, StringType} object PostgreSQLDialect { - def postgreSQLDialectRules(conf: SQLConf): List[Rule[LogicalPlan]] = - if (conf.usePostgreSQLDialect) { - CastStringToBoolean(conf) :: - Nil - } else { + val postgreSQLDialectRules: List[Rule[LogicalPlan]] = + CastStringToBoolean :: Nil - } - case class CastStringToBoolean(conf: SQLConf) extends Rule[LogicalPlan] with Logging { + object CastStringToBoolean extends Rule[LogicalPlan] with Logging { override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressions { - case Cast(child, dataType, _) if dataType == BooleanType && child.dataType == StringType => - PostgreCastStringToBoolean(child) + // The SQL configuration `spark.sql.dialect` can be changed in runtime. + // To make sure the configuration is effective, we have to check it during rule execution. + val conf = SQLConf.get + if (conf.usePostgreSQLDialect) { + plan.transformExpressions { + case Cast(child, dataType, _) + if dataType == BooleanType && child.dataType == StringType => + PostgreCastStringToBoolean(child) + } + } else { + plan } } } From 82a9e4d684032641a8356a5baacf5bf9d2414e6a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 25 Sep 2019 14:11:55 +0800 Subject: [PATCH 06/10] revise --- .../apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala index c5d0fd5282516..934e53703e241 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala @@ -38,7 +38,7 @@ object PostgreSQLDialect { if (conf.usePostgreSQLDialect) { plan.transformExpressions { case Cast(child, dataType, _) - if dataType == BooleanType && child.dataType == StringType => + if dataType == BooleanType && child.dataType == StringType => PostgreCastStringToBoolean(child) } } else { From 72a1539996c5d90daec0649e7212123742bf6d36 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 25 Sep 2019 22:27:48 +0800 Subject: [PATCH 07/10] add NullIntolerant trait --- .../expressions/postgreSQL/PostgreCastStringToBoolean.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala index 515f74f2df7be..3a66d0184a15f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.util.postgreSQL.StringUtils import org.apache.spark.sql.types.{BooleanType, DataType, StringType} import org.apache.spark.unsafe.types.UTF8String -case class PostgreCastStringToBoolean(child: Expression) extends UnaryExpression { +case class PostgreCastStringToBoolean(child: Expression) + extends UnaryExpression with NullIntolerant { override def checkInputDataTypes(): TypeCheckResult = { if (child.dataType == StringType) { From af97b9902e423de9bf801cd0d478e7ad6a44f8ea Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 26 Sep 2019 00:01:42 +0800 Subject: [PATCH 08/10] address comments --- .../apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala | 2 +- .../org/apache/spark/sql/catalyst/expressions/CastSuite.scala | 2 +- .../spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala index 80d95a466f9f5..dbc058451a442 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala @@ -22,7 +22,7 @@ import org.apache.spark.unsafe.types.UTF8String object StringUtils { // "true", "yes", "1", "false", "no", "0", and unique prefixes of these strings are accepted. private[this] val trueStrings = - Set("true", "tru", "tr", "t", "yes", "ye", "y", "on", "1").map (UTF8String.fromString) + Set("true", "tru", "tr", "t", "yes", "ye", "y", "on", "1").map(UTF8String.fromString) private[this] val falseStrings = Set("false", "fals", "fal", "fa", "f", "no", "n", "off", "of", "0").map(UTF8String.fromString) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 38aacb2c97b1c..b3a53dc9fedf2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -818,7 +818,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { "interval 1 years 3 months -3 days") } - test("cast string to boolean with Spark dialect") { + test("cast string to boolean") { checkCast("t", true) checkCast("true", true) checkCast("tRUe", true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala index a8d5a3ad52fda..175904da21969 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala @@ -24,7 +24,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(PostgreCastStringToBoolean(Literal(v)), expected) } - test("cast string to boolean with PostgreSQL dialect") { + test("cast string to boolean") { checkPostgreCastStringToBoolean("true", true) checkPostgreCastStringToBoolean("tru", true) checkPostgreCastStringToBoolean("tr", true) From b5181145f157b1d874cde3acfad9e20efb0c2bf0 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 26 Sep 2019 12:41:49 +0800 Subject: [PATCH 09/10] address more comments --- .../apache/spark/sql/PostgreSQLDialectQuerySuite.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala index 4b800d4743030..1354dcfda45fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala @@ -23,13 +23,12 @@ import org.apache.spark.sql.test.SharedSparkSession class PostgreSQLDialectQuerySuite extends QueryTest with SharedSparkSession { override def sparkConf: SparkConf = - super.sparkConf - .set(SQLConf.DIALECT.key, SQLConf.Dialect.POSTGRESQL.toString) + super.sparkConf.set(SQLConf.DIALECT.key, SQLConf.Dialect.POSTGRESQL.toString) test("cast string to boolean") { - Seq("true", "tru", "tr", "t", " tRue ", " tRu ", "yes", "ye", - "y", "1", "on").foreach { input => - checkAnswer(sql(s"select cast('$input' as boolean)"), Row(true)) + Seq("true", "tru", "tr", "t", " tRue ", " tRu ", "yes", "ye", + "y", "1", "on").foreach { input => + checkAnswer(sql(s"select cast('$input' as boolean)"), Row(true)) } Seq("false", "fals", "fal", "fa", "f", " fAlse ", " fAls ", "no", "n", "0", "off", "of").foreach { input => From 75aeda757b10b8601cb386a53a87d888fc6556e1 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 26 Sep 2019 15:39:07 +0800 Subject: [PATCH 10/10] one more comment --- .../postgreSQL/PostgreCastStringToBoolean.scala | 8 +++++--- .../spark/sql/catalyst/util/postgreSQL/StringUtils.scala | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala index 3a66d0184a15f..0e87707d01e47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastStringToBoolean.scala @@ -37,7 +37,7 @@ case class PostgreCastStringToBoolean(child: Expression) } override def nullSafeEval(input: Any): Any = { - val s = input.asInstanceOf[UTF8String] + val s = input.asInstanceOf[UTF8String].trim().toLowerCase() if (StringUtils.isTrueString(s)) { true } else if (StringUtils.isFalseString(s)) { @@ -51,14 +51,16 @@ case class PostgreCastStringToBoolean(child: Expression) val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" val eval = child.genCode(ctx) val javaType = JavaCode.javaType(dataType) + val preprocessedString = ctx.freshName("preprocessedString") val castCode = code""" boolean ${ev.isNull} = ${eval.isNull}; $javaType ${ev.value} = false; if (!${eval.isNull}) { - if ($stringUtils.isTrueString(${eval.value})) { + UTF8String $preprocessedString = ${eval.value}.trim().toLowerCase(); + if ($stringUtils.isTrueString($preprocessedString)) { ${ev.value} = true; - } else if ($stringUtils.isFalseString(${eval.value})) { + } else if ($stringUtils.isFalseString($preprocessedString)) { ${ev.value} = false; } else { ${ev.isNull} = true; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala index dbc058451a442..1ae15df29d6e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/postgreSQL/StringUtils.scala @@ -27,7 +27,7 @@ object StringUtils { private[this] val falseStrings = Set("false", "fals", "fal", "fa", "f", "no", "n", "off", "of", "0").map(UTF8String.fromString) - def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.trim().toLowerCase()) + def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s) - def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.trim().toLowerCase()) + def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s) }