OracleIntegrationSuite.scala

@@ -198,4 +198,49 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
    val types = rows(0).toSeq.map(x => x.getClass.toString)
    assert(types(1).equals("class java.sql.Timestamp"))
  }
+
+  test("SPARK-18004: Make sure date or timestamp related predicate is pushed down correctly") {
+    val props = new Properties()
+    props.put("oracle.jdbc.mapDateToTimestamp", "false")
+
+    val schema = StructType(Seq(
+      StructField("date_type", DateType, true),
+      StructField("timestamp_type", TimestampType, true)
+    ))
+
+    val tableName = "test_date_timestamp_pushdown"
+    val dateVal = Date.valueOf("2017-06-22")
+    val timestampVal = Timestamp.valueOf("2017-06-22 21:30:07")
+
+    val data = spark.sparkContext.parallelize(Seq(
+      Row(dateVal, timestampVal)
+    ))
+
+    val dfWrite = spark.createDataFrame(data, schema)
+    dfWrite.write.jdbc(jdbcUrl, tableName, props)
+
+    val dfRead = spark.read.jdbc(jdbcUrl, tableName, props)
+
+    val millis = System.currentTimeMillis()
+    val dt = new java.sql.Date(millis)
+    val ts = new java.sql.Timestamp(millis)
+
+    // Query the Oracle table with date and timestamp predicates,
+    // which should be pushed down to Oracle.
+    val df = dfRead.filter(dfRead.col("date_type").lt(dt))
+      .filter(dfRead.col("timestamp_type").lt(ts))
+
+    val metadata = df.queryExecution.sparkPlan.metadata
+    // A "PushedFilters" entry should exist in the DataFrame's physical
+    // plan, and the presence of the right literals in it shows that the
+    // predicate pushdown was effective.
+    assert(metadata.get("PushedFilters").ne(None))
+    assert(metadata("PushedFilters").contains(dt.toString))
+    assert(metadata("PushedFilters").contains(ts.toString))
+
+    val row = df.collect()(0)
+    assert(row.getDate(0).equals(dateVal))
+    assert(row.getTimestamp(1).equals(timestampVal))
+  }
Member: Just want to confirm whether you ran the docker test. Thanks!

Contributor Author: Sorry for the late reply. I added the docker test to confirm that date/timestamp-related predicates can be pushed down to Oracle and executed correctly. Before this patch, such predicates could be pushed down to Oracle but could not be executed.

}
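
For context, a minimal sketch (not part of the PR; the column name and error code are illustrative) of the WHERE clauses the old and new code paths generate for a pushed-down date predicate:

    import java.sql.Date

    val d = Date.valueOf("2017-06-22")

    // Old path: JDBCRDD.compileValue quoted the Date's toString as a plain
    // string literal, which Oracle typically cannot compare against a DATE
    // column (it fails with a literal/format error such as ORA-01861).
    val before = s"date_type < '$d'"      // date_type < '2017-06-22'

    // New path: OracleDialect.compileValue uses the JDBC escape syntax,
    // which the Oracle driver rewrites into a proper DATE literal.
    val after = s"date_type < {d '$d'}"   // date_type < {d '2017-06-22'}
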
JDBCRDD.scala

@@ -17,12 +17,10 @@

package org.apache.spark.sql.execution.datasources.jdbc

-import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp}
+import java.sql.{Connection, PreparedStatement, ResultSet, SQLException}

import scala.util.control.NonFatal

-import org.apache.commons.lang3.StringUtils

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
@@ -86,20 +84,6 @@ object JDBCRDD extends Logging {
new StructType(columns.map(name => fieldMap(name)))
}

-  /**
-   * Converts value to SQL expression.
-   */
-  private def compileValue(value: Any): Any = value match {
-    case stringValue: String => s"'${escapeSql(stringValue)}'"
-    case timestampValue: Timestamp => "'" + timestampValue + "'"
-    case dateValue: Date => "'" + dateValue + "'"
-    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
-    case _ => value
-  }
-
-  private def escapeSql(value: String): String =
-    if (value == null) null else StringUtils.replace(value, "'", "''")

/**
* Turns a single Filter into a String representing a SQL expression.
* Returns None for an unhandled filter.
@@ -108,23 +92,24 @@
def quote(colName: String): String = dialect.quoteIdentifier(colName)

Option(f match {
-      case EqualTo(attr, value) => s"${quote(attr)} = ${compileValue(value)}"
+      case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}"
case EqualNullSafe(attr, value) =>
val col = quote(attr)
s"(NOT ($col != ${compileValue(value)} OR $col IS NULL OR " +
s"${compileValue(value)} IS NULL) OR ($col IS NULL AND ${compileValue(value)} IS NULL))"
case LessThan(attr, value) => s"${quote(attr)} < ${compileValue(value)}"
case GreaterThan(attr, value) => s"${quote(attr)} > ${compileValue(value)}"
case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${compileValue(value)}"
case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${compileValue(value)}"
s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " +
s"${dialect.compileValue(value)} IS NULL) OR " +
s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))"
case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}"
case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}"
case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}"
case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}"
case IsNull(attr) => s"${quote(attr)} IS NULL"
case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL"
case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'"
case StringEndsWith(attr, value) => s"${quote(attr)} LIKE '%${value}'"
case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'"
case In(attr, value) if value.isEmpty =>
s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END"
-      case In(attr, value) => s"${quote(attr)} IN (${compileValue(value)})"
+      case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})"
case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null)
case Or(f1, f2) =>
// We can't compile Or filter unless both sub-filters are compiled successfully.
JdbcDialect.scala

@@ -17,7 +17,9 @@

package org.apache.spark.sql.jdbc

-import java.sql.Connection
+import java.sql.{Connection, Date, Timestamp}
+
+import org.apache.commons.lang3.StringUtils

import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
import org.apache.spark.sql.types._
@@ -123,6 +125,29 @@ abstract class JdbcDialect extends Serializable {
def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
}

+  /**
+   * Escape special characters in SQL string literals.
+   * @param value The string to be escaped.
+   * @return Escaped string.
+   */
+  @Since("2.3.0")
+  protected[jdbc] def escapeSql(value: String): String =
+    if (value == null) null else StringUtils.replace(value, "'", "''")
Member: So far, it only covers single quotes. Actually, this is not complete. We need to improve it in future PRs. If you have time, feel free to submit PRs.

Member (@gatorsmile, Jul 2, 2017):

    def escapeSql(value: String): String = {
      if (value == null) null else StringUtils.replace(value, "'", "''")
    }

Contributor Author: I am very glad to be involved!
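
A quick illustrative sketch of the reviewer's point (not part of the PR): the helper only doubles embedded single quotes, the standard SQL escaping, and leaves everything else untouched:

    import org.apache.commons.lang3.StringUtils

    def escapeSql(value: String): String =
      if (value == null) null else StringUtils.replace(value, "'", "''")

    escapeSql("O'Brien")   // "O''Brien" -- safe inside a '...' literal
    escapeSql("a\\b")      // "a\b" -- backslashes pass through unchanged,
                           // which databases that treat backslash as an
                           // escape character may misread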


+  /**
+   * Converts value to SQL expression.
+   * @param value The value to be converted.
+   * @return Converted value.
+   */
+  @Since("2.3.0")
+  def compileValue(value: Any): Any = value match {
+    case stringValue: String => s"'${escapeSql(stringValue)}'"
+    case timestampValue: Timestamp => "'" + timestampValue + "'"
+    case dateValue: Date => "'" + dateValue + "'"
+    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+    case _ => value
+  }

/**
* Return Some[true] iff `TRUNCATE TABLE` causes cascading default.
* Some[true] : TRUNCATE TABLE causes cascading.
OracleDialect.scala

@@ -17,7 +17,7 @@

package org.apache.spark.sql.jdbc

-import java.sql.Types
+import java.sql.{Date, Timestamp, Types}

import org.apache.spark.sql.types._

@@ -68,5 +68,18 @@ private case object OracleDialect extends JdbcDialect {
case _ => None
}

+  override def compileValue(value: Any): Any = value match {
+    // The JDBC drivers support date literals in SQL statements written in the
+    // format: {d 'yyyy-mm-dd'} and timestamp literals in SQL statements written
+    // in the format: {ts 'yyyy-mm-dd hh:mm:ss.f...'}. For details, see
+    // 'Oracle Database JDBC Developer’s Guide and Reference, 11g Release 1 (11.1)'
+    // Appendix A Reference Information.
+    case stringValue: String => s"'${escapeSql(stringValue)}'"
+    case timestampValue: Timestamp => "{ts '" + timestampValue + "'}"
+    case dateValue: Date => "{d '" + dateValue + "'}"
+    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+    case _ => value
+  }

override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}
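
As a hedged usage sketch (assuming the call site sits in the same package, since OracleDialect is package-private), the override produces the Oracle-friendly literal forms the new test relies on:

    import java.sql.{Date, Timestamp}

    OracleDialect.compileValue(Date.valueOf("2017-06-22"))
    // -> {d '2017-06-22'}
    OracleDialect.compileValue(Timestamp.valueOf("2017-06-22 21:30:07"))
    // -> {ts '2017-06-22 21:30:07.0'}  (Timestamp.toString appends .0)
    OracleDialect.compileValue("it's")
    // -> 'it''s'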