diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 7072ee4b4e3b..c74574d280a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -179,7 +179,7 @@ private[sql] object JDBCRDD extends Logging { case stringValue: String => s"'${escapeSql(stringValue)}'" case timestampValue: Timestamp => "'" + timestampValue + "'" case dateValue: Date => "'" + dateValue + "'" - case arrayValue: Array[Object] => arrayValue.map(compileValue).mkString(", ") + case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") case _ => value } @@ -188,24 +188,41 @@ private[sql] object JDBCRDD extends Logging { /** * Turns a single Filter into a String representing a SQL expression. - * Returns null for an unhandled filter. + * Returns None for an unhandled filter. */ - private def compileFilter(f: Filter): String = f match { - case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" - case Not(f) => s"(NOT (${compileFilter(f)}))" - case LessThan(attr, value) => s"$attr < ${compileValue(value)}" - case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" - case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" - case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" - case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'" - case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'" - case StringContains(attr, value) => s"${attr} LIKE '%${value}%'" - case IsNull(attr) => s"$attr IS NULL" - case IsNotNull(attr) => s"$attr IS NOT NULL" - case In(attr, value) => s"$attr IN (${compileValue(value)})" - case Or(f1, f2) => s"(${compileFilter(f1)}) OR (${compileFilter(f2)})" - case And(f1, f2) => s"(${compileFilter(f1)}) AND (${compileFilter(f2)})" - case _ => null + private def compileFilter(f: Filter): Option[String] = { + Option(f match { + case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case LessThan(attr, value) => s"$attr < ${compileValue(value)}" + case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" + case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case IsNull(attr) => s"$attr IS NULL" + case IsNotNull(attr) => s"$attr IS NOT NULL" + case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'" + case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'" + case StringContains(attr, value) => s"${attr} LIKE '%${value}%'" + case In(attr, value) => s"$attr IN (${compileValue(value)})" + case Not(f) => compileFilter(f).map(p => s"(NOT ($p))").getOrElse(null) + case Or(f1, f2) => + // We can't compile Or filter unless both sub-filters are compiled successfully. + // It applies too for the following And filter. + // If we can make sure compileFilter supports all filters, we can remove this check. + val or = Seq(f1, f2).map(compileFilter(_)).flatten + if (or.size == 2) { + or.map(p => s"($p)").mkString(" OR ") + } else { + null + } + case And(f1, f2) => + val and = Seq(f1, f2).map(compileFilter(_)).flatten + if (and.size == 2) { + and.map(p => s"($p)").mkString(" AND ") + } else { + null + } + case _ => null + }) } /** @@ -307,25 +324,21 @@ private[sql] class JDBCRDD( /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - private val filterWhereClause: String = { - val filterStrings = filters.map(JDBCRDD.compileFilter).filter(_ != null) - if (filterStrings.size > 0) { - val sb = new StringBuilder("WHERE ") - filterStrings.foreach(x => sb.append(x).append(" AND ")) - sb.substring(0, sb.length - 5) - } else "" - } + private val filterWhereClause: String = + filters.map(JDBCRDD.compileFilter).flatten.mkString(" AND ") /** * A WHERE clause representing both `filters`, if any, and the current partition. */ private def getWhereClause(part: JDBCPartition): String = { if (part.whereClause != null && filterWhereClause.length > 0) { - filterWhereClause + " AND " + part.whereClause + "WHERE " + filterWhereClause + " AND " + part.whereClause } else if (part.whereClause != null) { "WHERE " + part.whereClause + } else if (filterWhereClause.length > 0) { + "WHERE " + filterWhereClause } else { - filterWhereClause + "" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 00e37f107a88..633ae215fc12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -190,7 +190,7 @@ class JDBCSuite extends SparkFunSuite assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME IN ('mary', 'fred')")) .collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME NOT IN ('fred')")) - .collect().size === 2) + .collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary'")) .collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary' " @@ -453,8 +453,8 @@ class JDBCSuite extends SparkFunSuite } test("compile filters") { - val compileFilter = PrivateMethod[String]('compileFilter) - def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) + val compileFilter = PrivateMethod[Option[String]]('compileFilter) + def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) getOrElse("") assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3") assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "(NOT (col1 = 'abc'))") assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def"))) @@ -473,6 +473,7 @@ class JDBCSuite extends SparkFunSuite === "(NOT (col1 IN ('mno', 'pqr')))") assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL") assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL") + assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def"))) === "") } test("Dialect unregister") {