Skip to content

Commit d0b2c22

Browse files
committed
Add a config to fall back to string literal parsing consistent with the old SQL parser behavior.
1 parent b0a1e93 commit d0b2c22

File tree

8 files changed

+103
-9
lines changed

8 files changed

+103
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class SessionCatalog(
7373
functionRegistry,
7474
conf,
7575
new Configuration(),
76-
CatalystSqlParser,
76+
new CatalystSqlParser(conf),
7777
DummyFunctionResourceLoader)
7878
}
7979

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
3636
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
3737
import org.apache.spark.sql.catalyst.plans._
3838
import org.apache.spark.sql.catalyst.plans.logical._
39+
import org.apache.spark.sql.internal.SQLConf
3940
import org.apache.spark.sql.types._
4041
import org.apache.spark.unsafe.types.CalendarInterval
4142
import org.apache.spark.util.random.RandomSampler
@@ -44,9 +45,11 @@ import org.apache.spark.util.random.RandomSampler
4445
* The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
4546
* TableIdentifier.
4647
*/
47-
class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
48+
class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging {
4849
import ParserUtils._
4950

51+
def this() = this(new SQLConf())
52+
5053
protected def typedVisit[T](ctx: ParseTree): T = {
5154
ctx.accept(this).asInstanceOf[T]
5255
}
@@ -1406,7 +1409,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
14061409
* Special characters can be escaped by using Hive/C-style escaping.
14071410
*/
14081411
private def createString(ctx: StringLiteralContext): String = {
1409-
ctx.STRING().asScala.map(string).mkString
1412+
if (conf.noUnescapedStringLiteral) {
1413+
ctx.STRING().asScala.map(stringWithoutUnescape).mkString
1414+
} else {
1415+
ctx.STRING().asScala.map(string).mkString
1416+
}
14101417
}
14111418

14121419
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
2626
import org.apache.spark.sql.catalyst.expressions.Expression
2727
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2828
import org.apache.spark.sql.catalyst.trees.Origin
29+
import org.apache.spark.sql.internal.SQLConf
2930
import org.apache.spark.sql.types.{DataType, StructType}
3031

3132
/**
@@ -120,8 +121,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
120121
/**
121122
* Concrete SQL parser for Catalyst-only SQL statements.
122123
*/
124+
class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser {
125+
val astBuilder = new AstBuilder(conf)
126+
}
127+
128+
/** For test-only. */
123129
object CatalystSqlParser extends AbstractSqlParser {
124-
val astBuilder = new AstBuilder
130+
val astBuilder = new AstBuilder(new SQLConf())
125131
}
126132

127133
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ object ParserUtils {
6868
/** Convert a string node into a string. */
6969
def string(node: TerminalNode): String = unescapeSQLString(node.getText)
7070

71+
/** Convert a string node into a string without unescaping. */
72+
def stringWithoutUnescape(node: TerminalNode): String = {
73+
node.getText.slice(1, node.getText.size - 1)
74+
}
75+
7176
/** Get the origin (line and position) of the token. */
7277
def position(token: Token): Origin = {
7378
val opt = Option(token)

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,14 @@ object SQLConf {
196196
.booleanConf
197197
.createWithDefault(true)
198198

199+
val NO_UNESCAPED_SQL_STRING = buildConf("spark.sql.noUnescapedStringLiteral")
200+
.internal()
201+
.doc("Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL " +
202+
"parser. This is different from the 1.6 behavior. Enabling this config makes the parser " +
203+
"leave string literals unescaped, mitigating migration problems.")
204+
.booleanConf
205+
.createWithDefault(false)
206+
199207
val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema")
200208
.doc("When true, the Parquet data source merges schemas collected from all data files, " +
201209
"otherwise the schema is picked from the summary file or a random data file " +
@@ -911,6 +919,8 @@ class SQLConf extends Serializable with Logging {
911919

912920
def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
913921

922+
def noUnescapedStringLiteral: Boolean = getConf(NO_UNESCAPED_SQL_STRING)
923+
914924
/**
915925
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
916926
* identifiers are equal.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
2525
import org.apache.spark.sql.catalyst.plans.PlanTest
26+
import org.apache.spark.sql.internal.SQLConf
2627
import org.apache.spark.sql.types._
2728
import org.apache.spark.unsafe.types.CalendarInterval
2829

@@ -39,12 +40,17 @@ class ExpressionParserSuite extends PlanTest {
3940
import org.apache.spark.sql.catalyst.dsl.expressions._
4041
import org.apache.spark.sql.catalyst.dsl.plans._
4142

42-
def assertEqual(sqlCommand: String, e: Expression): Unit = {
43-
compareExpressions(parseExpression(sqlCommand), e)
43+
val defaultParser = CatalystSqlParser
44+
45+
def assertEqual(
46+
sqlCommand: String,
47+
e: Expression,
48+
parser: ParserInterface = defaultParser): Unit = {
49+
compareExpressions(parser.parseExpression(sqlCommand), e)
4450
}
4551

4652
def intercept(sqlCommand: String, messages: String*): Unit = {
47-
val e = intercept[ParseException](parseExpression(sqlCommand))
53+
val e = intercept[ParseException](defaultParser.parseExpression(sqlCommand))
4854
messages.foreach { message =>
4955
assert(e.message.contains(message))
5056
}
@@ -101,7 +107,7 @@ class ExpressionParserSuite extends PlanTest {
101107
test("long binary logical expressions") {
102108
def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = {
103109
val sql = (1 to 1000).map(x => s"$x == $x").mkString(op)
104-
val e = parseExpression(sql)
110+
val e = defaultParser.parseExpression(sql)
105111
assert(e.collect { case _: EqualTo => true }.size === 1000)
106112
assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999)
107113
}
@@ -160,6 +166,15 @@ class ExpressionParserSuite extends PlanTest {
160166
assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%"))
161167
}
162168

169+
test("like expressions with NO_UNESCAPED_SQL_STRING") {
170+
val conf = new SQLConf()
171+
conf.setConfString("spark.sql.noUnescapedStringLiteral", "true")
172+
val parser = new CatalystSqlParser(conf)
173+
assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser)
174+
assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser)
175+
assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser)
176+
}
177+
163178
test("is null expressions") {
164179
assertEqual("a is null", 'a.isNull)
165180
assertEqual("a is not null", 'a.isNotNull)
@@ -447,6 +462,44 @@ class ExpressionParserSuite extends PlanTest {
447462
assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)")
448463
}
449464

465+
test("strings with NO_UNESCAPED_SQL_STRING") {
466+
val conf = new SQLConf()
467+
conf.setConfString("spark.sql.noUnescapedStringLiteral", "true")
468+
val parser = new CatalystSqlParser(conf)
469+
470+
// Single Strings.
471+
assertEqual("\"hello\"", "hello", parser)
472+
assertEqual("'hello'", "hello", parser)
473+
474+
// Multi-Strings.
475+
assertEqual("\"hello\" 'world'", "helloworld", parser)
476+
assertEqual("'hello' \" \" 'world'", "hello world", parser)
477+
478+
assertEqual("'pattern%'", "pattern%", parser)
479+
assertEqual("'no-pattern\\%'", "no-pattern\\%", parser)
480+
assertEqual("'pattern\\\\%'", "pattern\\\\%", parser)
481+
assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser)
482+
483+
// Escaped characters.
484+
assertEqual("'\0'", "\u0000", parser) // ASCII NUL (X'00')
485+
486+
// Note: Single quote follows 1.6 parsing behavior when NO_UNESCAPED_SQL_STRING is enabled.
487+
val e = intercept[ParseException](parser.parseExpression("'\''"))
488+
assert(e.message.contains("extraneous input '''"))
489+
490+
assertEqual("'\"'", "\"", parser) // Double quote
491+
assertEqual("'\b'", "\b", parser) // Backspace
492+
assertEqual("'\n'", "\n", parser) // Newline
493+
assertEqual("'\r'", "\r", parser) // Carriage return
494+
assertEqual("'\t'", "\t", parser) // Tab character
495+
496+
// Octals
497+
assertEqual("'\110\145\154\154\157\041'", "Hello!", parser)
498+
499+
// Unicode
500+
assertEqual("'\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029'", "World :)", parser)
501+
}
502+
450503
test("intervals") {
451504
def intervalLiteral(u: String, s: String): Literal = {
452505
Literal(CalendarInterval.fromSingleUnitString(u, s))

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser {
5252
/**
5353
* Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
5454
*/
55-
class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
55+
class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
5656
import org.apache.spark.sql.catalyst.parser.ParserUtils._
5757

5858
/**

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
2626
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
2727
import org.apache.spark.sql.execution.streaming.MemoryStream
2828
import org.apache.spark.sql.functions._
29+
import org.apache.spark.sql.internal.SQLConf
2930
import org.apache.spark.sql.test.SharedSQLContext
3031
import org.apache.spark.sql.types._
3132

@@ -1168,6 +1169,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
11681169
val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS()
11691170
checkDataset(ds, WithMapInOption(Some(Map(1 -> 1))))
11701171
}
1172+
1173+
test("do not unescape regex pattern string") {
1174+
withSQLConf(SQLConf.NO_UNESCAPED_SQL_STRING.key -> "true") {
1175+
val data = Seq("\u0020\u0021\u0023", "abc")
1176+
val df = data.toDF()
1177+
val rlike1 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'")
1178+
val rlike2 = df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$"))
1179+
val rlike3 = df.filter("value rlike '^\\\\x20[\\\\x20-\\\\x23]+$'")
1180+
checkAnswer(rlike1, rlike2)
1181+
assert(rlike3.count() == 0)
1182+
}
1183+
}
11711184
}
11721185

11731186
case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])

0 commit comments

Comments
 (0)