Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,37 @@ private[sql] object JDBCRDD extends Logging {
* @return A Catalyst schema corresponding to columns in the given order.
*/
private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
val fieldMap = Map(schema.fields map { x => x.metadata.getString("name") -> x }: _*)
new StructType(columns map { name => fieldMap(name) })
val fieldMap = Map(schema.fields.map(x => x.metadata.getString("name") -> x): _*)
new StructType(columns.map(name => fieldMap(name)))
}

/**
* Converts value to SQL expression.
*/
private def compileValue(value: Any): Any = value match {
case stringValue: String => s"'${escapeSql(stringValue)}'"
case timestampValue: Timestamp => "'" + timestampValue + "'"
case dateValue: Date => "'" + dateValue + "'"
case _ => value
}

private def escapeSql(value: String): String =
if (value == null) null else StringUtils.replace(value, "'", "''")

/**
* Turns a single Filter into a String representing a SQL expression.
* Returns null for an unhandled filter.
*/
private def compileFilter(f: Filter): String = f match {
case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}"
case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
case IsNull(attr) => s"$attr IS NULL"
case IsNotNull(attr) => s"$attr IS NOT NULL"
case _ => null
}


Expand Down Expand Up @@ -240,37 +269,12 @@ private[sql] class JDBCRDD(
if (sb.length == 0) "1" else sb.substring(1)
}

/**
* Converts value to SQL expression.
*/
private def compileValue(value: Any): Any = value match {
case stringValue: String => s"'${escapeSql(stringValue)}'"
case timestampValue: Timestamp => "'" + timestampValue + "'"
case dateValue: Date => "'" + dateValue + "'"
case _ => value
}

private def escapeSql(value: String): String =
if (value == null) null else StringUtils.replace(value, "'", "''")

/**
* Turns a single Filter into a String representing a SQL expression.
* Returns null for an unhandled filter.
*/
private def compileFilter(f: Filter): String = f match {
case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
case _ => null
}

/**
* `filters`, but as a WHERE clause suitable for injection into a SQL query.
*/
private val filterWhereClause: String = {
val filterStrings = filters map compileFilter filter (_ != null)
val filterStrings = filters.map(JDBCRDD.compileFilter).filter(_ != null)
if (filterStrings.size > 0) {
val sb = new StringBuilder("WHERE ")
filterStrings.foreach(x => sb.append(x).append(" AND "))
Expand Down
24 changes: 22 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,22 @@
package org.apache.spark.sql.jdbc

import java.math.BigDecimal
import java.sql.DriverManager
import java.sql.{Date, DriverManager, Timestamp}
import java.util.{Calendar, GregorianCalendar, Properties}

import org.h2.jdbc.JdbcSQLException
import org.scalatest.BeforeAndAfter
import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources._
import org.apache.spark.util.Utils

class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
class JDBCSuite extends SparkFunSuite
with BeforeAndAfter with PrivateMethodTester with SharedSQLContext {
import testImplicits._

val url = "jdbc:h2:mem:testdb0"
Expand Down Expand Up @@ -427,6 +431,22 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
assert(DerbyColumns === Seq(""""abc"""", """"key""""))
}

test("compile filters") {
val compileFilter = PrivateMethod[String]('compileFilter)
def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f)
assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3")
assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "col1 != 'abc'")
assert(doCompileFilter(LessThan("col0", 5)) === "col0 < 5")
assert(doCompileFilter(LessThan("col3",
Timestamp.valueOf("1995-11-21 00:00:00.0"))) === "col3 < '1995-11-21 00:00:00.0'")
assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) === "col4 < '1983-08-04'")
assert(doCompileFilter(LessThanOrEqual("col0", 5)) === "col0 <= 5")
assert(doCompileFilter(GreaterThan("col0", 3)) === "col0 > 3")
assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === "col0 >= 3")
assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL")
assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL")
}

test("Dialect unregister") {
JdbcDialects.registerDialect(testH2Dialect)
JdbcDialects.unregisterDialect(testH2Dialect)
Expand Down