
Commit 609ba5f

viirya authored and cloud-fan committed
[SPARK-20399][SQL] Add a config to fallback string literal parsing consistent with old sql parser behavior
## What changes were proposed in this pull request?

The new SQL parser was introduced in Spark 2.0, and it unescapes all string literals. This brings an issue for regex pattern strings. The following code reproduces it:

```scala
val data = Seq("\u0020\u0021\u0023", "abc")
val df = data.toDF()

// 1st usage: works in 1.6.
// Let the parser parse the pattern string.
val rlike1 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'")

// 2nd usage: works in 1.6 and 2.x.
// Call Column.rlike, so the pattern string is a literal that doesn't go through the parser.
val rlike2 = df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$"))

// In 2.x, we need to add backslashes to make the regex pattern parse correctly.
val rlike3 = df.filter("value rlike '^\\\\x20[\\\\x20-\\\\x23]+$'")
```

Following the discussion in #17736, this patch adds a config to fall back to 1.6 string literal parsing and mitigate the migration issue.

## How was this patch tested?

Jenkins tests.

Author: Liang-Chi Hsieh <[email protected]>

Closes #17887 from viirya/add-config-fallback-string-parsing.
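As a quick illustration of the fallback, a minimal sketch (hypothetical spark-shell session; assumes a SparkSession named `spark` with its implicits imported — the config name and the rlike behavior are taken from this patch):

```scala
import spark.implicits._

// Opt into 1.6-style literal parsing (the config added by this patch).
spark.conf.set("spark.sql.parser.escapedStringLiterals", "true")

val df = Seq("\u0020\u0021\u0023", "abc").toDF()

// Goes through the SQL parser: with the fallback enabled, the 1.6-style
// pattern is kept as-is and matches " !#".
df.filter("value rlike '^\\x20[\\x20-\\x23]+$'").show()

// Column.rlike never goes through the SQL parser, so it matches " !#"
// regardless of the config.
df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$")).show()
```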
1 parent 04901dd commit 609ba5f

File tree: 9 files changed, +171 −42 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -73,7 +73,7 @@ class SessionCatalog(
       functionRegistry,
       conf,
       new Configuration(),
-      CatalystSqlParser,
+      new CatalystSqlParser(conf),
       DummyFunctionResourceLoader)
   }
 
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala

Lines changed: 32 additions & 1 deletion
```diff
@@ -86,6 +86,13 @@ abstract class StringRegexExpression extends BinaryExpression
       escape character, the following character is matched literally. It is invalid to escape
       any other character.
 
+      Since Spark 2.0, string literals are unescaped in our SQL parser. For example, in order
+      to match "\abc", the pattern should be "\\abc".
+
+      When SQL config 'spark.sql.parser.escapedStringLiterals' is enabled, it fallbacks
+      to Spark 1.6 behavior regarding string literal parsing. For example, if the config is
+      enabled, the pattern to match "\abc" should be "\abc".
+
     Examples:
       > SELECT '%SystemDrive%\Users\John' _FUNC_ '\%SystemDrive\%\\Users%'
       true
@@ -144,7 +151,31 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpression
 }
 
 @ExpressionDescription(
-  usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.")
+  usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.",
+  extended = """
+    Arguments:
+      str - a string expression
+      regexp - a string expression. The pattern string should be a Java regular expression.
+
+        Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL parser.
+        For example, to match "\abc", a regular expression for `regexp` can be "^\\abc$".
+
+        There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to fallback
+        to the Spark 1.6 behavior regarding string literal parsing. For example, if the config is
+        enabled, the `regexp` that can match "\abc" is "^\abc$".
+
+    Examples:
+      When spark.sql.parser.escapedStringLiterals is disabled (default).
+      > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\\Users.*'
+      true
+
+      When spark.sql.parser.escapedStringLiterals is enabled.
+      > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\Users.*'
+      true
+
+    See also:
+      Use LIKE to match with simple string pattern.
+  """)
 case class RLike(left: Expression, right: Expression) extends StringRegexExpression {
 
   override def escape(v: String): String = v
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 9 additions & 2 deletions
```diff
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
 import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.random.RandomSampler
@@ -44,9 +45,11 @@ import org.apache.spark.util.random.RandomSampler
  * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
  * TableIdentifier.
  */
-class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
+class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging {
   import ParserUtils._
 
+  def this() = this(new SQLConf())
+
   protected def typedVisit[T](ctx: ParseTree): T = {
     ctx.accept(this).asInstanceOf[T]
   }
@@ -1423,7 +1426,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
    * Special characters can be escaped by using Hive/C-style escaping.
    */
   private def createString(ctx: StringLiteralContext): String = {
-    ctx.STRING().asScala.map(string).mkString
+    if (conf.escapedStringLiterals) {
+      ctx.STRING().asScala.map(stringWithoutUnescape).mkString
+    } else {
+      ctx.STRING().asScala.map(string).mkString
+    }
   }
 
   /**
```
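To make the two `createString` branches concrete, a small sketch (parser construction mirrors the updated test suites below; the expected literal values follow the `'pattern\\%'` cases in ExpressionParserSuite):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.SQLConf

// The Scala string "'pattern\\\\%'" is the SQL text 'pattern\\%'.
val sql = "'pattern\\\\%'"

// Default branch: createString unescapes, so \\ collapses to \ (value: pattern\%).
new CatalystSqlParser(new SQLConf()).parseExpression(sql)

// Fallback branch: stringWithoutUnescape keeps backslashes verbatim (value: pattern\\%).
val conf = new SQLConf()
conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true")
new CatalystSqlParser(conf).parseExpression(sql)
```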

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala

Lines changed: 7 additions & 1 deletion
```diff
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
@@ -121,8 +122,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
 /**
  * Concrete SQL parser for Catalyst-only SQL statements.
  */
+class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser {
+  val astBuilder = new AstBuilder(conf)
+}
+
+/** For test-only. */
 object CatalystSqlParser extends AbstractSqlParser {
-  val astBuilder = new AstBuilder
+  val astBuilder = new AstBuilder(new SQLConf())
 }
 
 /**
```
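The split leaves two entry points: the conf-aware class for production code (as SessionCatalog now uses) and the zero-argument companion object for tests. A hedged sketch of both call sites (the identifier `"db.tbl"` is illustrative):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.SQLConf

// Test-only companion object: always backed by a fresh default SQLConf.
val t1 = CatalystSqlParser.parseTableIdentifier("db.tbl")

// Conf-aware instance, matching how SessionCatalog now constructs its parser.
val parser = new CatalystSqlParser(new SQLConf())
val t2 = parser.parseTableIdentifier("db.tbl")
```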

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala

Lines changed: 6 additions & 0 deletions
```diff
@@ -68,6 +68,12 @@ object ParserUtils {
   /** Convert a string node into a string. */
   def string(node: TerminalNode): String = unescapeSQLString(node.getText)
 
+  /** Convert a string node into a string without unescaping. */
+  def stringWithoutUnescape(node: TerminalNode): String = {
+    // STRING parser rule forces that the input always has quotes at the starting and ending.
+    node.getText.slice(1, node.getText.size - 1)
+  }
+
   /** Get the origin (line and position) of the token. */
   def position(token: Token): Origin = {
     val opt = Option(token)
```
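A hedged illustration of the helper's slicing on a raw token text (shown on a plain string for clarity; in the real method the text comes from an ANTLR TerminalNode, and `raw`/`value` are hypothetical names):

```scala
// The STRING grammar rule guarantees the token text starts and ends with a
// quote, so dropping the first and last character is sufficient; every escape
// sequence in between is kept verbatim.
val raw = "'pattern\\\\%'"             // token text: 'pattern\\%'
val value = raw.slice(1, raw.size - 1) // pattern\\%
assert(value == "pattern\\\\%")
```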

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -196,6 +196,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
+    .internal()
+    .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
+      "parser. The default is false since Spark 2.0. Setting it to true can restore the behavior " +
+      "prior to Spark 2.0.")
+    .booleanConf
+    .createWithDefault(false)
+
   val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema")
     .doc("When true, the Parquet data source merges schemas collected from all data files, " +
       "otherwise the schema is picked from the summary file or a random data file " +
@@ -917,6 +925,8 @@ class SQLConf extends Serializable with Logging {
 
   def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
 
+  def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
```
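The new accessor exposes the flag wherever a SQLConf is in scope. A sketch of flipping it per-session from SQL and reading it back (assumes a SparkSession `spark`; the config is marked internal, but it can still be set by name):

```scala
// Enable the 1.6 fallback for this session only.
spark.sql("SET spark.sql.parser.escapedStringLiterals=true")
assert(spark.conf.get("spark.sql.parser.escapedStringLiterals").toBoolean)

// Restore the post-2.0 default.
spark.sql("SET spark.sql.parser.escapedStringLiterals=false")
```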

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala

Lines changed: 92 additions & 36 deletions
```diff
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
 import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
@@ -39,12 +40,17 @@ class ExpressionParserSuite extends PlanTest {
   import org.apache.spark.sql.catalyst.dsl.expressions._
   import org.apache.spark.sql.catalyst.dsl.plans._
 
-  def assertEqual(sqlCommand: String, e: Expression): Unit = {
-    compareExpressions(parseExpression(sqlCommand), e)
+  val defaultParser = CatalystSqlParser
+
+  def assertEqual(
+      sqlCommand: String,
+      e: Expression,
+      parser: ParserInterface = defaultParser): Unit = {
+    compareExpressions(parser.parseExpression(sqlCommand), e)
   }
 
   def intercept(sqlCommand: String, messages: String*): Unit = {
-    val e = intercept[ParseException](parseExpression(sqlCommand))
+    val e = intercept[ParseException](defaultParser.parseExpression(sqlCommand))
     messages.foreach { message =>
       assert(e.message.contains(message))
     }
@@ -101,7 +107,7 @@ class ExpressionParserSuite extends PlanTest {
   test("long binary logical expressions") {
     def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = {
       val sql = (1 to 1000).map(x => s"$x == $x").mkString(op)
-      val e = parseExpression(sql)
+      val e = defaultParser.parseExpression(sql)
       assert(e.collect { case _: EqualTo => true }.size === 1000)
       assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999)
     }
@@ -160,6 +166,15 @@ class ExpressionParserSuite extends PlanTest {
     assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%"))
   }
 
+  test("like expressions with ESCAPED_STRING_LITERALS = true") {
+    val conf = new SQLConf()
+    conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true")
+    val parser = new CatalystSqlParser(conf)
+    assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser)
+    assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser)
+    assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser)
+  }
+
   test("is null expressions") {
     assertEqual("a is null", 'a.isNull)
     assertEqual("a is not null", 'a.isNotNull)
@@ -418,38 +433,79 @@
   }
 
   test("strings") {
-    // Single Strings.
-    assertEqual("\"hello\"", "hello")
-    assertEqual("'hello'", "hello")
-
-    // Multi-Strings.
-    assertEqual("\"hello\" 'world'", "helloworld")
-    assertEqual("'hello' \" \" 'world'", "hello world")
-
-    // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
-    // regular '%'; to get the correct result you need to add another escaped '\'.
-    // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
-    assertEqual("'pattern%'", "pattern%")
-    assertEqual("'no-pattern\\%'", "no-pattern\\%")
-    assertEqual("'pattern\\\\%'", "pattern\\%")
-    assertEqual("'pattern\\\\\\%'", "pattern\\\\%")
-
-    // Escaped characters.
-    // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
-    assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00')
-    assertEqual("'\\''", "\'") // Single quote
-    assertEqual("'\\\"'", "\"") // Double quote
-    assertEqual("'\\b'", "\b") // Backspace
-    assertEqual("'\\n'", "\n") // Newline
-    assertEqual("'\\r'", "\r") // Carriage return
-    assertEqual("'\\t'", "\t") // Tab character
-    assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows)
-
-    // Octals
-    assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!")
-
-    // Unicode
-    assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)")
+    Seq(true, false).foreach { escape =>
+      val conf = new SQLConf()
+      conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString)
+      val parser = new CatalystSqlParser(conf)
+
+      // tests that have same result whatever the conf is
+      // Single Strings.
+      assertEqual("\"hello\"", "hello", parser)
+      assertEqual("'hello'", "hello", parser)
+
+      // Multi-Strings.
+      assertEqual("\"hello\" 'world'", "helloworld", parser)
+      assertEqual("'hello' \" \" 'world'", "hello world", parser)
+
+      // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
+      // regular '%'; to get the correct result you need to add another escaped '\'.
+      // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
+      assertEqual("'pattern%'", "pattern%", parser)
+      assertEqual("'no-pattern\\%'", "no-pattern\\%", parser)
+
+      // tests that have different result regarding the conf
+      if (escape) {
+        // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to
+        // Spark 1.6 behavior.
+
+        // 'LIKE' string literals.
+        assertEqual("'pattern\\\\%'", "pattern\\\\%", parser)
+        assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser)
+
+        // Escaped characters.
+        assertEqual("'\0'", "\u0000", parser) // ASCII NUL (X'00')
+
+        // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled.
+        val e = intercept[ParseException](parser.parseExpression("'\''"))
+        assert(e.message.contains("extraneous input '''"))
+
+        assertEqual("'\"'", "\"", parser) // Double quote
+        assertEqual("'\b'", "\b", parser) // Backspace
+        assertEqual("'\n'", "\n", parser) // Newline
+        assertEqual("'\r'", "\r", parser) // Carriage return
+        assertEqual("'\t'", "\t", parser) // Tab character
+
+        // Octals
+        assertEqual("'\110\145\154\154\157\041'", "Hello!", parser)
+        // Unicode
+        assertEqual("'\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029'", "World :)", parser)
+      } else {
+        // Default behavior
+
+        // 'LIKE' string literals.
+        assertEqual("'pattern\\\\%'", "pattern\\%", parser)
+        assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser)
+
+        // Escaped characters.
+        // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
+        assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00')
+        assertEqual("'\\''", "\'", parser) // Single quote
+        assertEqual("'\\\"'", "\"", parser) // Double quote
+        assertEqual("'\\b'", "\b", parser) // Backspace
+        assertEqual("'\\n'", "\n", parser) // Newline
+        assertEqual("'\\r'", "\r", parser) // Carriage return
+        assertEqual("'\\t'", "\t", parser) // Tab character
+        assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows)
+
+        // Octals
+        assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser)
+
+        // Unicode
+        assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)",
+          parser)
+      }
+    }
   }
 
   test("intervals") {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -52,7 +52,7 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser {
 /**
  * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
  */
-class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
+class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
   import org.apache.spark.sql.catalyst.parser.ParserUtils._
 
   /**
```

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 13 additions & 0 deletions
```diff
@@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
@@ -1168,6 +1169,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS()
     checkDataset(ds, WithMapInOption(Some(Map(1 -> 1))))
   }
+
+  test("SPARK-20399: do not unescaped regex pattern when ESCAPED_STRING_LITERALS is enabled") {
+    withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") {
+      val data = Seq("\u0020\u0021\u0023", "abc")
+      val df = data.toDF()
+      val rlike1 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'")
+      val rlike2 = df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$"))
+      val rlike3 = df.filter("value rlike '^\\\\x20[\\\\x20-\\\\x23]+$'")
+      checkAnswer(rlike1, rlike2)
+      assert(rlike3.count() == 0)
+    }
+  }
 }
 
 case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
```
