@@ -189,7 +189,7 @@ private[sql] object JDBCRDD extends Logging {
* Turns a single Filter into a String representing a SQL expression.
* Returns None for an unhandled filter.
*/
- private def compileFilter(f: Filter): Option[String] = {
+ private[jdbc] def compileFilter(f: Filter): Option[String] = {
Option(f match {
case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
case EqualNullSafe(attr, value) =>
@@ -90,6 +90,11 @@ private[sql] case class JDBCRelation(

override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

+ // Filters are unhandled when JDBCRDD.compileFilter cannot compile them
+ override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
+   filters.filter(JDBCRDD.compileFilter(_).isEmpty)
Review comment from a Contributor:

Thank you for the update!

When we were using == null (I guess all filters were marked as unhandled, right?), all the tests still passed, so I am wondering whether the existing tests are sufficient.
Reply from the Member (Author):

I added tests. That said, it seems to me that JDBCRDD.compileFilter does not return None for any given Filter here, because the function can compile every Filter implemented in sql.sources.filters.

+ }

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
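To make the thread above concrete, here is a minimal sketch of how an Option-returning filter compiler composes with unhandledFilters. It is illustrative only, not the PR's code: the real compileFilter lives in JDBCRDD and also quotes literals via compileValue. The key point is that None propagates upward, so a compound filter is handled only when every leaf compiles.

```scala
import org.apache.spark.sql.sources._

// Minimal sketch; names and cases are illustrative, not Spark's implementation.
object FilterCompilerSketch {
  // Compile a source Filter into a SQL predicate string; None means "unhandled".
  def compile(f: Filter): Option[String] = f match {
    case EqualTo(attr, value)  => Some(s"$attr = $value") // real code quotes values
    case LessThan(attr, value) => Some(s"$attr < $value")
    // NOT must wrap the compiled child as a whole (cf. SPARK-12218); rewriting
    // NOT(a AND b) into NOT(a) AND NOT(b) would change the results.
    case Not(child) => compile(child).map(p => s"(NOT ($p))")
    // A conjunction or disjunction is handled only if both sides compile, so a
    // single uncompilable leaf makes the whole filter unhandled.
    case And(left, right) =>
      for (l <- compile(left); r <- compile(right)) yield s"($l) AND ($r)"
    case Or(left, right) =>
      for (l <- compile(left); r <- compile(right)) yield s"($l) OR ($r)"
    case _ => None
  }

  // The unhandledFilters contract: report exactly the filters the source cannot
  // evaluate itself; Spark keeps its own Filter node only for those.
  def unhandledFilters(filters: Array[Filter]): Array[Filter] =
    filters.filter(compile(_).isEmpty)
}
```

Returning an empty array from unhandledFilters advertises that every pushed filter is fully evaluated by the database, which is what allows the planner to drop the Spark-side Filter that the tests below inspect for.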
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala (70 changes: 50 additions, 20 deletions)
@@ -22,12 +22,12 @@
import java.sql.{Date, DriverManager, Timestamp}
import java.util.{Calendar, GregorianCalendar, Properties}

import org.h2.jdbc.JdbcSQLException
- import org.scalatest.BeforeAndAfter
- import org.scalatest.PrivateMethodTester
+ import org.scalatest.{BeforeAndAfter, PrivateMethodTester}

- import org.apache.spark.SparkFunSuite
- import org.apache.spark.sql.Row
- import org.apache.spark.sql.execution.ExplainCommand
+ import org.apache.spark.SparkFunSuite
+ import org.apache.spark.sql.{DataFrame, Row}
+ import org.apache.spark.sql.execution.PhysicalRDD
+ import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
import org.apache.spark.sql.sources._
@@ -183,33 +183,63 @@ class JDBCSuite extends SparkFunSuite
}

test("SELECT * WHERE (simple predicates)") {
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME <=> 'fred'")).collect().size == 1)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME IN ('mary', 'fred')"))
def checkPushdown(df: DataFrame): DataFrame = {
val parentPlan = df.queryExecution.executedPlan
// Check if SparkPlan Filter is removed in a physical plan and
// the plan only has PhysicalRDD to scan JDBCRelation.
assert(parentPlan.isInstanceOf[PhysicalRDD])
assert(parentPlan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation"))
df
}
assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0)
assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2)
assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1)
assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1)
assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME <=> 'fred'")).collect().size == 1)
assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2)
assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2)
assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME IN ('mary', 'fred')"))
.collect().size == 2)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME NOT IN ('fred')"))
assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME NOT IN ('fred')"))
.collect().size == 2)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary'"))
assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary'"))
.collect().size == 2)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary' "
assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary' "
+ "AND THEID = 2")).collect().size == 2)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE 'fr%'")).collect().size == 1)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE '%ed'")).collect().size == 1)
assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE '%re%'")).collect().size == 1)
assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1)
assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0)
assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME LIKE 'fr%'")).collect().size == 1)
assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME LIKE '%ed'")).collect().size == 1)
assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME LIKE '%re%'")).collect().size == 1)
assert(checkPushdown(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1)
assert(checkPushdown(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0)

+ // This test reflects the discussion in SPARK-12218.
+ // Older versions of Spark had this kind of bug in the Parquet data source.
+ val df1 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2 AND NAME != 'mary')")
+ val df2 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2) OR NOT (NAME != 'mary')")
+ assert(df1.collect.toSet === Set(Row("mary", 2)))
+ assert(df2.collect.toSet === Set(Row("mary", 2)))
+
+ def checkNotPushdown(df: DataFrame): DataFrame = {
+   val parentPlan = df.queryExecution.executedPlan
+   // Check that the Spark-side Filter is kept in the physical plan because
+   // JDBCRDD cannot compile the given predicates.
+   assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen])
+   val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]
+   assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.Filter])
+   df
+ }
+ assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0)
+ assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 2) != 4")).collect().size == 2)
}
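A note on the last two assertions above: those predicates stay in Spark not because compileFilter returns None for them, but because Catalyst never produces an org.apache.spark.sql.sources.Filter for arithmetic over a column in the first place, so the JDBC relation is never asked about them. A rough, hypothetical illustration (translate and its cases are invented for this sketch, not Spark's API):

```scala
import org.apache.spark.sql.sources.{Filter, LessThan}

// The source Filter API only describes column-versus-value predicates such as
// LessThan("THEID", 1); there is no Filter shape for "(THEID + 1) < 2", so the
// translation step yields nothing and Spark evaluates that predicate itself.
def translate(predicate: String): Option[Filter] = predicate match {
  case "THEID < 1"       => Some(LessThan("THEID", 1)) // pushable to the database
  case "(THEID + 1) < 2" => None                       // no source Filter exists
  case _                 => None
}
```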

test("SELECT COUNT(1) WHERE (predicates)") {
// Check if an answer is correct when Filter is removed from operations such as count() which
// does not require any columns. In some data sources, e.g., Parquet, `requiredColumns` in
// org.apache.spark.sql.sources.interfaces is not given in logical plans, but some filters
// are applied for columns with Filter producing wrong results. On the other hand, JDBCRDD
// correctly handles this case by assigning `requiredColumns` properly. See PR 10427 for more
// discussions.
assert(sql("SELECT COUNT(1) FROM foobar WHERE NAME = 'mary'").collect.toSet === Set(Row(1)))
}
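The case the comment describes can be made concrete: for COUNT(1), buildScan receives an empty requiredColumns array, and a JDBC-style source can still apply the pushed filters by selecting a constant. A rough sketch of the idea, assuming a simplified query builder (buildQuery is invented here; the real logic is JDBCRDD's column-list handling):

```scala
// Invented helper for illustration; not JDBCRDD's actual code.
def buildQuery(table: String, requiredColumns: Array[String], where: Seq[String]): String = {
  // With no required columns, select a constant so the WHERE clause still applies.
  val columnList = if (requiredColumns.isEmpty) "1" else requiredColumns.mkString(", ")
  val whereClause = if (where.isEmpty) "" else where.mkString(" WHERE ", " AND ", "")
  s"SELECT $columnList FROM $table$whereClause"
}

// buildQuery("foobar", Array.empty, Seq("NAME = 'mary'"))
//   returns "SELECT 1 FROM foobar WHERE NAME = 'mary'"
// Every returned row feeds COUNT(1), so a source that silently dropped the filter
// when no columns are required would return a wrong count.
```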

test("SELECT * WHERE (quoted strings)") {