Skip to content

Commit 0e7f281

Browse files
committed
dialect
1 parent f725d47 commit 0e7f281

File tree

8 files changed

+108
-56
lines changed

8 files changed

+108
-56
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,8 +677,10 @@ object TypeCoercion {
677677
case d: Divide if d.dataType == DoubleType => d
678678
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
679679
case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) =>
680+
val preferIntegralDivision =
681+
conf.getConf(SQLConf.DIALECT) == SQLConf.Dialect.POSTGRESQL.toString
680682
(left.dataType, right.dataType) match {
681-
case (_: IntegralType, _: IntegralType) if conf.preferIntegralDivision =>
683+
case (_: IntegralType, _: IntegralType) if preferIntegralDivision =>
682684
IntegralDivide(left, right)
683685
case _ =>
684686
Divide(Cast(left, DoubleType), Cast(right, DoubleType))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -391,10 +391,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
391391
// UDFToBoolean
392392
private[this] def castToBoolean(from: DataType): Any => Any = from match {
393393
case StringType =>
394+
val dialect = SQLConf.get.getConf(SQLConf.DIALECT)
394395
buildCast[UTF8String](_, s => {
395-
if (StringUtils.isTrueString(s)) {
396+
if (StringUtils.isTrueString(s, dialect)) {
396397
true
397-
} else if (StringUtils.isFalseString(s)) {
398+
} else if (StringUtils.isFalseString(s, dialect)) {
398399
false
399400
} else {
400401
null
@@ -1250,11 +1251,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
12501251
private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
12511252
case StringType =>
12521253
val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
1254+
val dialect = SQLConf.get.getConf(SQLConf.DIALECT)
12531255
(c, evPrim, evNull) =>
12541256
code"""
1255-
if ($stringUtils.isTrueString($c)) {
1257+
if ($stringUtils.isTrueString($c, "$dialect")) {
12561258
$evPrim = true;
1257-
} else if ($stringUtils.isFalseString($c)) {
1259+
} else if ($stringUtils.isFalseString($c, "$dialect")) {
12581260
$evPrim = false;
12591261
} else {
12601262
$evNull = true;

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,34 @@ object StringUtils extends Logging {
6565
"(?s)" + out.result() // (?s) enables dotall mode, causing "." to match new lines
6666
}
6767

68-
// "true", "yes", "1", "false", "no", "0", and unique prefixes of these strings are accepted.
6968
private[this] val trueStrings =
70-
Set("true", "tru", "tr", "t", "yes", "ye", "y", "on", "1").map(UTF8String.fromString)
69+
Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
70+
// PostgreSQL dialect: "true", "yes", "1", "false", "no", "0", and unique prefixes of these strings are accepted.
71+
private[this] val trueStringsOfPostgreSQL =
72+
Set("true", "tru", "tr", "t", "yes", "ye", "y", "on", "1").map(UTF8String.fromString)
7173

7274
private[this] val falseStrings =
75+
Set("f", "false", "n", "no", "0").map(UTF8String.fromString)
76+
private[this] val falseStringsOfPostgreSQL =
7377
Set("false", "fals", "fal", "fa", "f", "no", "n", "off", "of", "0").map(UTF8String.fromString)
74-
7578
// scalastyle:off caselocale
76-
def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase.trim())
77-
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase.trim())
79+
def isTrueString(s: UTF8String, dialect: String): Boolean = {
80+
SQLConf.Dialect.withName(dialect) match {
81+
case SQLConf.Dialect.SPARK =>
82+
trueStrings.contains(s.toLowerCase)
83+
case SQLConf.Dialect.POSTGRESQL =>
84+
trueStringsOfPostgreSQL.contains(s.toLowerCase.trim())
85+
}
86+
}
87+
88+
def isFalseString(s: UTF8String, dialect: String): Boolean = {
89+
SQLConf.Dialect.withName(dialect) match {
90+
case SQLConf.Dialect.SPARK =>
91+
falseStrings.contains(s.toLowerCase)
92+
case SQLConf.Dialect.POSTGRESQL =>
93+
falseStringsOfPostgreSQL.contains(s.toLowerCase.trim())
94+
}
95+
}
7896
// scalastyle:on caselocale
7997

8098
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,12 +1589,22 @@ object SQLConf {
15891589
.booleanConf
15901590
.createWithDefault(false)
15911591

1592-
val PREFER_INTEGRAL_DIVISION = buildConf("spark.sql.function.preferIntegralDivision")
1593-
.internal()
1594-
.doc("When true, will perform integral division with the / operator " +
1595-
"if both sides are integral types. This is for PostgreSQL test cases only.")
1596-
.booleanConf
1597-
.createWithDefault(false)
1592+
object Dialect extends Enumeration {
1593+
val SPARK, POSTGRESQL = Value
1594+
}
1595+
1596+
val DIALECT =
1597+
buildConf("spark.sql.dialect")
1598+
.doc("The specific features of the SQL language to be adopted, which are available when " +
1599+
"accessing the given database. Currently, Spark supports two database dialects, `Spark` " +
1600+
"and `PostgreSQL`. With `PostgreSQL` dialect, Spark will: " +
1601+
"1. perform integral division with the / operator if both sides are integral types; " +
1602+
"2. accept \"true\", \"yes\", \"1\", \"false\", \"no\", \"0\", and unique prefixes as " +
1603+
"input and trim input for the boolean data type.")
1604+
.stringConf
1605+
.transform(_.toUpperCase(Locale.ROOT))
1606+
.checkValues(Dialect.values.map(_.toString))
1607+
.createWithDefault(Dialect.SPARK.toString)
15981608

15991609
val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION =
16001610
buildConf("spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation")
@@ -2418,8 +2428,6 @@ class SQLConf extends Serializable with Logging {
24182428

24192429
def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING)
24202430

2421-
def preferIntegralDivision: Boolean = getConf(PREFER_INTEGRAL_DIVISION)
2422-
24232431
def allowCreatingManagedTableUsingNonemptyLocation: Boolean =
24242432
getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION)
24252433

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,15 +1483,15 @@ class TypeCoercionSuite extends AnalysisTest {
14831483

14841484
test("SPARK-28395 Division operator support integral division") {
14851485
val rules = Seq(FunctionArgumentConversion, Division(conf))
1486-
Seq(true, false).foreach { preferIntegralDivision =>
1487-
withSQLConf(SQLConf.PREFER_INTEGRAL_DIVISION.key -> s"$preferIntegralDivision") {
1488-
val result1 = if (preferIntegralDivision) {
1486+
Seq(SQLConf.Dialect.SPARK, SQLConf.Dialect.POSTGRESQL).foreach { dialect =>
1487+
withSQLConf(SQLConf.DIALECT.key -> dialect.toString) {
1488+
val result1 = if (dialect == SQLConf.Dialect.POSTGRESQL) {
14891489
IntegralDivide(1L, 1L)
14901490
} else {
14911491
Divide(Cast(1L, DoubleType), Cast(1L, DoubleType))
14921492
}
14931493
ruleTest(rules, Divide(1L, 1L), result1)
1494-
val result2 = if (preferIntegralDivision) {
1494+
val result2 = if (dialect == SQLConf.Dialect.POSTGRESQL) {
14951495
IntegralDivide(1, Cast(1, ShortType))
14961496
} else {
14971497
Divide(Cast(1, DoubleType), Cast(Cast(1, ShortType), DoubleType))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -818,37 +818,60 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
818818
"interval 1 years 3 months -3 days")
819819
}
820820

821-
test("cast string to boolean") {
822-
checkCast("true", true)
823-
checkCast("tru", true)
824-
checkCast("tr", true)
825-
checkCast("t", true)
826-
checkCast("tRUe", true)
827-
checkCast(" tRue ", true)
828-
checkCast(" tRu ", true)
829-
checkCast("yes", true)
830-
checkCast("ye", true)
831-
checkCast("y", true)
832-
checkCast("1", true)
833-
checkCast("on", true)
834-
835-
checkCast("false", false)
836-
checkCast("fals", false)
837-
checkCast("fal", false)
838-
checkCast("fa", false)
839-
checkCast("f", false)
840-
checkCast(" fAlse ", false)
841-
checkCast(" fAls ", false)
842-
checkCast(" FAlsE ", false)
843-
checkCast("no", false)
844-
checkCast("n", false)
845-
checkCast("0", false)
846-
checkCast("off", false)
847-
checkCast("of", false)
848-
849-
checkEvaluation(cast("o", BooleanType), null)
850-
checkEvaluation(cast("abc", BooleanType), null)
851-
checkEvaluation(cast("", BooleanType), null)
821+
test("cast string to boolean with Spark dialect") {
822+
withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.SPARK.toString) {
823+
checkCast("t", true)
824+
checkCast("true", true)
825+
checkCast("tRUe", true)
826+
checkCast("y", true)
827+
checkCast("yes", true)
828+
checkCast("1", true)
829+
830+
checkCast("f", false)
831+
checkCast("false", false)
832+
checkCast("FAlsE", false)
833+
checkCast("n", false)
834+
checkCast("no", false)
835+
checkCast("0", false)
836+
837+
checkEvaluation(cast("abc", BooleanType), null)
838+
checkEvaluation(cast("", BooleanType), null)
839+
}
840+
}
841+
842+
test("cast string to boolean with PostgreSQL dialect") {
843+
withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.POSTGRESQL.toString) {
844+
checkCast("true", true)
845+
checkCast("tru", true)
846+
checkCast("tr", true)
847+
checkCast("t", true)
848+
checkCast("tRUe", true)
849+
checkCast(" tRue ", true)
850+
checkCast(" tRu ", true)
851+
checkCast("yes", true)
852+
checkCast("ye", true)
853+
checkCast("y", true)
854+
checkCast("1", true)
855+
checkCast("on", true)
856+
857+
checkCast("false", false)
858+
checkCast("fals", false)
859+
checkCast("fal", false)
860+
checkCast("fa", false)
861+
checkCast("f", false)
862+
checkCast(" fAlse ", false)
863+
checkCast(" fAls ", false)
864+
checkCast(" FAlsE ", false)
865+
checkCast("no", false)
866+
checkCast("n", false)
867+
checkCast("0", false)
868+
checkCast("off", false)
869+
checkCast("of", false)
870+
871+
checkEvaluation(cast("o", BooleanType), null)
872+
checkEvaluation(cast("abc", BooleanType), null)
873+
checkEvaluation(cast("", BooleanType), null)
874+
}
852875
}
853876

854877
test("SPARK-16729 type checking for casting to date type") {

sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
311311
// PostgreSQL enabled cartesian product by default.
312312
localSparkSession.conf.set(SQLConf.CROSS_JOINS_ENABLED.key, true)
313313
localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, true)
314-
localSparkSession.conf.set(SQLConf.PREFER_INTEGRAL_DIVISION.key, true)
315-
localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, true)
314+
localSparkSession.conf.set(SQLConf.DIALECT.key, SQLConf.Dialect.POSTGRESQL.toString)
316315
case _ =>
317316
}
318317

sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
111111
// PostgreSQL enabled cartesian product by default.
112112
statement.execute(s"SET ${SQLConf.CROSS_JOINS_ENABLED.key} = true")
113113
statement.execute(s"SET ${SQLConf.ANSI_ENABLED.key} = true")
114-
statement.execute(s"SET ${SQLConf.PREFER_INTEGRAL_DIVISION.key} = true")
114+
statement.execute(s"SET ${SQLConf.DIALECT.key} = ${SQLConf.Dialect.POSTGRESQL.toString}")
115115
case _ =>
116116
}
117117

0 commit comments

Comments (0)