Skip to content

Commit 9e386b4

Browse files
mihailom-db authored and cloud-fan committed
[SPARK-48172][SQL] Fix escaping issues in JDBCDialects
This PR is a fix of #46437. The previous PR was reverted as `LONGTEXT` is not supported by all dialects. ### What changes were proposed in this pull request? Special case escaping for MySQL and fix issues with redundant escaping for ' character. New changes introduced in the fix include change `LONGTEXT` -> `VARCHAR(50)`, as well as fix for table naming in the tests. ### Why are the changes needed? When pushing down startsWith, endsWith and contains they are converted to LIKE. This requires addition of escape characters for these expressions. Unfortunately, MySQL uses ESCAPE '\' syntax instead of ESCAPE '' which would cause errors when trying to push down. ### Does this PR introduce any user-facing change? Yes ### How was this patch tested? Tests for each existing dialect. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46588 from mihailom-db/SPARK-48172. Authored-by: Mihailo Milosevic <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent d0385c4 commit 9e386b4

File tree

12 files changed

+291
-12
lines changed

12 files changed

+291
-12
lines changed

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
6262
connection.prepareStatement(
6363
"CREATE TABLE employee (dept INTEGER, name VARCHAR(10), salary DECIMAL(20, 2), bonus DOUBLE)")
6464
.executeUpdate()
65+
connection.prepareStatement(
66+
s"""CREATE TABLE pattern_testing_table (
67+
|pattern_testing_col VARCHAR(50)
68+
|)
69+
""".stripMargin
70+
).executeUpdate()
6571
}
6672

6773
override def testUpdateColumnType(tbl: String): Unit = {

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite {
3838
.executeUpdate()
3939
connection.prepareStatement("INSERT INTO employee VALUES (6, 'jen', 12000, 1200)")
4040
.executeUpdate()
41+
42+
connection.prepareStatement(
43+
s"""
44+
|INSERT INTO pattern_testing_table VALUES
45+
|('special_character_quote''_present'),
46+
|('special_character_quote_not_present'),
47+
|('special_character_percent%_present'),
48+
|('special_character_percent_not_present'),
49+
|('special_character_underscore_present'),
50+
|('special_character_underscorenot_present')
51+
""".stripMargin).executeUpdate()
4152
}
4253

4354
def tablePreparation(connection: Connection): Unit

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
7070
connection.prepareStatement(
7171
"CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)")
7272
.executeUpdate()
73+
connection.prepareStatement(
74+
s"""CREATE TABLE pattern_testing_table (
75+
|pattern_testing_col VARCHAR(50)
76+
|)
77+
""".stripMargin
78+
).executeUpdate()
7379
}
7480

7581
override def notSupportsTableComment: Boolean = true

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
7373
connection.prepareStatement(
7474
"CREATE TABLE employee (dept INT, name VARCHAR(32), salary DECIMAL(20, 2)," +
7575
" bonus DOUBLE)").executeUpdate()
76+
connection.prepareStatement(
77+
s"""CREATE TABLE pattern_testing_table (
78+
|pattern_testing_col LONGTEXT
79+
|)
80+
""".stripMargin
81+
).executeUpdate()
7682
}
7783

7884
override def testUpdateColumnType(tbl: String): Unit = {

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
9393
connection.prepareStatement(
9494
"CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," +
9595
" bonus BINARY_DOUBLE)").executeUpdate()
96+
connection.prepareStatement(
97+
s"""CREATE TABLE pattern_testing_table (
98+
|pattern_testing_col VARCHAR(50)
99+
|)
100+
""".stripMargin
101+
).executeUpdate()
96102
}
97103

98104
override def testUpdateColumnType(tbl: String): Unit = {

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
5959
connection.prepareStatement(
6060
"CREATE TABLE employee (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," +
6161
" bonus double precision)").executeUpdate()
62+
connection.prepareStatement(
63+
s"""CREATE TABLE pattern_testing_table (
64+
|pattern_testing_col VARCHAR(50)
65+
|)
66+
""".stripMargin
67+
).executeUpdate()
6268
}
6369

6470
override def testUpdateColumnType(tbl: String): Unit = {

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,235 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
359359
assert(scan.schema.names.sameElements(Seq(col)))
360360
}
361361

362+
test("SPARK-48172: Test CONTAINS") {
363+
val df1 = spark.sql(
364+
s"""
365+
|SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
366+
|WHERE contains(pattern_testing_col, 'quote\\'')""".stripMargin)
367+
df1.explain("formatted")
368+
val rows1 = df1.collect()
369+
assert(rows1.length === 1)
370+
assert(rows1(0).getString(0) === "special_character_quote'_present")
371+
372+
val df2 = spark.sql(
373+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
374+
|WHERE contains(pattern_testing_col, 'percent%')""".stripMargin)
375+
val rows2 = df2.collect()
376+
assert(rows2.length === 1)
377+
assert(rows2(0).getString(0) === "special_character_percent%_present")
378+
379+
val df3 = spark.
380+
sql(
381+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
382+
|WHERE contains(pattern_testing_col, 'underscore_')""".stripMargin)
383+
val rows3 = df3.collect()
384+
assert(rows3.length === 1)
385+
assert(rows3(0).getString(0) === "special_character_underscore_present")
386+
387+
val df4 = spark.
388+
sql(
389+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
390+
|WHERE contains(pattern_testing_col, 'character')
391+
|ORDER BY pattern_testing_col""".stripMargin)
392+
val rows4 = df4.collect()
393+
assert(rows4.length === 6)
394+
assert(rows4(0).getString(0) === "special_character_percent%_present")
395+
assert(rows4(1).getString(0) === "special_character_percent_not_present")
396+
assert(rows4(2).getString(0) === "special_character_quote'_present")
397+
assert(rows4(3).getString(0) === "special_character_quote_not_present")
398+
assert(rows4(4).getString(0) === "special_character_underscore_present")
399+
assert(rows4(5).getString(0) === "special_character_underscorenot_present")
400+
}
401+
402+
test("SPARK-48172: Test ENDSWITH") {
403+
val df1 = spark.sql(
404+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
405+
|WHERE endswith(pattern_testing_col, 'quote\\'_present')""".stripMargin)
406+
val rows1 = df1.collect()
407+
assert(rows1.length === 1)
408+
assert(rows1(0).getString(0) === "special_character_quote'_present")
409+
410+
val df2 = spark.sql(
411+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
412+
|WHERE endswith(pattern_testing_col, 'percent%_present')""".stripMargin)
413+
val rows2 = df2.collect()
414+
assert(rows2.length === 1)
415+
assert(rows2(0).getString(0) === "special_character_percent%_present")
416+
417+
val df3 = spark.
418+
sql(
419+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
420+
|WHERE endswith(pattern_testing_col, 'underscore_present')""".stripMargin)
421+
val rows3 = df3.collect()
422+
assert(rows3.length === 1)
423+
assert(rows3(0).getString(0) === "special_character_underscore_present")
424+
425+
val df4 = spark.
426+
sql(
427+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
428+
|WHERE endswith(pattern_testing_col, 'present')
429+
|ORDER BY pattern_testing_col""".stripMargin)
430+
val rows4 = df4.collect()
431+
assert(rows4.length === 6)
432+
assert(rows4(0).getString(0) === "special_character_percent%_present")
433+
assert(rows4(1).getString(0) === "special_character_percent_not_present")
434+
assert(rows4(2).getString(0) === "special_character_quote'_present")
435+
assert(rows4(3).getString(0) === "special_character_quote_not_present")
436+
assert(rows4(4).getString(0) === "special_character_underscore_present")
437+
assert(rows4(5).getString(0) === "special_character_underscorenot_present")
438+
}
439+
440+
test("SPARK-48172: Test STARTSWITH") {
441+
val df1 = spark.sql(
442+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
443+
|WHERE startswith(pattern_testing_col, 'special_character_quote\\'')""".stripMargin)
444+
val rows1 = df1.collect()
445+
assert(rows1.length === 1)
446+
assert(rows1(0).getString(0) === "special_character_quote'_present")
447+
448+
val df2 = spark.sql(
449+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
450+
|WHERE startswith(pattern_testing_col, 'special_character_percent%')""".stripMargin)
451+
val rows2 = df2.collect()
452+
assert(rows2.length === 1)
453+
assert(rows2(0).getString(0) === "special_character_percent%_present")
454+
455+
val df3 = spark.
456+
sql(
457+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
458+
|WHERE startswith(pattern_testing_col, 'special_character_underscore_')""".stripMargin)
459+
val rows3 = df3.collect()
460+
assert(rows3.length === 1)
461+
assert(rows3(0).getString(0) === "special_character_underscore_present")
462+
463+
val df4 = spark.
464+
sql(
465+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
466+
|WHERE startswith(pattern_testing_col, 'special_character')
467+
|ORDER BY pattern_testing_col""".stripMargin)
468+
val rows4 = df4.collect()
469+
assert(rows4.length === 6)
470+
assert(rows4(0).getString(0) === "special_character_percent%_present")
471+
assert(rows4(1).getString(0) === "special_character_percent_not_present")
472+
assert(rows4(2).getString(0) === "special_character_quote'_present")
473+
assert(rows4(3).getString(0) === "special_character_quote_not_present")
474+
assert(rows4(4).getString(0) === "special_character_underscore_present")
475+
assert(rows4(5).getString(0) === "special_character_underscorenot_present")
476+
}
477+
478+
test("SPARK-48172: Test LIKE") {
479+
// this one should map to contains
480+
val df1 = spark.sql(
481+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
482+
|WHERE pattern_testing_col LIKE '%quote\\'%'""".stripMargin)
483+
val rows1 = df1.collect()
484+
assert(rows1.length === 1)
485+
assert(rows1(0).getString(0) === "special_character_quote'_present")
486+
487+
val df2 = spark.sql(
488+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
489+
|WHERE pattern_testing_col LIKE '%percent\\%%'""".stripMargin)
490+
val rows2 = df2.collect()
491+
assert(rows2.length === 1)
492+
assert(rows2(0).getString(0) === "special_character_percent%_present")
493+
494+
val df3 = spark.
495+
sql(
496+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
497+
|WHERE pattern_testing_col LIKE '%underscore\\_%'""".stripMargin)
498+
val rows3 = df3.collect()
499+
assert(rows3.length === 1)
500+
assert(rows3(0).getString(0) === "special_character_underscore_present")
501+
502+
val df4 = spark.
503+
sql(
504+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
505+
|WHERE pattern_testing_col LIKE '%character%'
506+
|ORDER BY pattern_testing_col""".stripMargin)
507+
val rows4 = df4.collect()
508+
assert(rows4.length === 6)
509+
assert(rows4(0).getString(0) === "special_character_percent%_present")
510+
assert(rows4(1).getString(0) === "special_character_percent_not_present")
511+
assert(rows4(2).getString(0) === "special_character_quote'_present")
512+
assert(rows4(3).getString(0) === "special_character_quote_not_present")
513+
assert(rows4(4).getString(0) === "special_character_underscore_present")
514+
assert(rows4(5).getString(0) === "special_character_underscorenot_present")
515+
516+
// map to startsWith
517+
// this one should map to contains
518+
val df5 = spark.sql(
519+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
520+
|WHERE pattern_testing_col LIKE 'special_character_quote\\'%'""".stripMargin)
521+
val rows5 = df5.collect()
522+
assert(rows5.length === 1)
523+
assert(rows5(0).getString(0) === "special_character_quote'_present")
524+
525+
val df6 = spark.sql(
526+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
527+
|WHERE pattern_testing_col LIKE 'special_character_percent\\%%'""".stripMargin)
528+
val rows6 = df6.collect()
529+
assert(rows6.length === 1)
530+
assert(rows6(0).getString(0) === "special_character_percent%_present")
531+
532+
val df7 = spark.
533+
sql(
534+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
535+
|WHERE pattern_testing_col LIKE 'special_character_underscore\\_%'""".stripMargin)
536+
val rows7 = df7.collect()
537+
assert(rows7.length === 1)
538+
assert(rows7(0).getString(0) === "special_character_underscore_present")
539+
540+
val df8 = spark.
541+
sql(
542+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
543+
|WHERE pattern_testing_col LIKE 'special_character%'
544+
|ORDER BY pattern_testing_col""".stripMargin)
545+
val rows8 = df8.collect()
546+
assert(rows8.length === 6)
547+
assert(rows8(0).getString(0) === "special_character_percent%_present")
548+
assert(rows8(1).getString(0) === "special_character_percent_not_present")
549+
assert(rows8(2).getString(0) === "special_character_quote'_present")
550+
assert(rows8(3).getString(0) === "special_character_quote_not_present")
551+
assert(rows8(4).getString(0) === "special_character_underscore_present")
552+
assert(rows8(5).getString(0) === "special_character_underscorenot_present")
553+
// map to endsWith
554+
// this one should map to contains
555+
val df9 = spark.sql(
556+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
557+
|WHERE pattern_testing_col LIKE '%quote\\'_present'""".stripMargin)
558+
val rows9 = df9.collect()
559+
assert(rows9.length === 1)
560+
assert(rows9(0).getString(0) === "special_character_quote'_present")
561+
562+
val df10 = spark.sql(
563+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
564+
|WHERE pattern_testing_col LIKE '%percent\\%_present'""".stripMargin)
565+
val rows10 = df10.collect()
566+
assert(rows10.length === 1)
567+
assert(rows10(0).getString(0) === "special_character_percent%_present")
568+
569+
val df11 = spark.
570+
sql(
571+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
572+
|WHERE pattern_testing_col LIKE '%underscore\\_present'""".stripMargin)
573+
val rows11 = df11.collect()
574+
assert(rows11.length === 1)
575+
assert(rows11(0).getString(0) === "special_character_underscore_present")
576+
577+
val df12 = spark.
578+
sql(
579+
s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
580+
|WHERE pattern_testing_col LIKE '%present' ORDER BY pattern_testing_col""".stripMargin)
581+
val rows12 = df12.collect()
582+
assert(rows12.length === 6)
583+
assert(rows12(0).getString(0) === "special_character_percent%_present")
584+
assert(rows12(1).getString(0) === "special_character_percent_not_present")
585+
assert(rows12(2).getString(0) === "special_character_quote'_present")
586+
assert(rows12(3).getString(0) === "special_character_quote_not_present")
587+
assert(rows12(4).getString(0) === "special_character_underscore_present")
588+
assert(rows12(5).getString(0) === "special_character_underscorenot_present")
589+
}
590+
362591
test("SPARK-37038: Test TABLESAMPLE") {
363592
if (supportsTableSample) {
364593
withTable(s"$catalogName.new_table") {

sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ protected String escapeSpecialCharsForLikePattern(String str) {
6565
switch (c) {
6666
case '_' -> builder.append("\\_");
6767
case '%' -> builder.append("\\%");
68-
case '\'' -> builder.append("\\\'");
6968
default -> builder.append(c);
7069
}
7170
}

sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.connector.expressions
1919

20+
import org.apache.commons.lang3.StringUtils
21+
2022
import org.apache.spark.SparkException
2123
import org.apache.spark.sql.catalyst
2224
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -388,7 +390,7 @@ private[sql] object HoursTransform {
388390
private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
389391
override def toString: String = {
390392
if (dataType.isInstanceOf[StringType]) {
391-
s"'$value'"
393+
s"'${StringUtils.replace(s"$value", "'", "''")}'"
392394
} else {
393395
s"$value"
394396
}

sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -259,13 +259,6 @@ private[sql] case class H2Dialect() extends JdbcDialect {
259259
}
260260

261261
class H2SQLBuilder extends JDBCSQLBuilder {
262-
override def escapeSpecialCharsForLikePattern(str: String): String = {
263-
str.map {
264-
case '_' => "\\_"
265-
case '%' => "\\%"
266-
case c => c.toString
267-
}.mkString
268-
}
269262

270263
override def visitAggregateFunction(
271264
funcName: String, isDistinct: Boolean, inputs: Array[String]): String =

0 commit comments

Comments (0)