JDBCOptions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.jdbc
import java.sql.{Connection, DriverManager}
import java.util.{Locale, Properties}

+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

/**
@@ -33,6 +34,14 @@ class JDBCOptions(

def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))

+@DeveloperApi
+def this(url: String) = {
+this(CaseInsensitiveMap(Map(
+JDBCOptions.JDBC_URL -> url,
+JDBCOptions.JDBC_DRIVER_CLASS -> "org.h2.Driver",
+JDBCOptions.JDBC_TABLE_NAME -> "")))
+}
+
def this(url: String, table: String, parameters: Map[String, String]) = {
this(CaseInsensitiveMap(parameters ++ Map(
JDBCOptions.JDBC_URL -> url,
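A minimal sketch (not part of the diff) of how the url-only constructor added above might be exercised; it assumes the internal JDBCOptions and JdbcDialects classes are visible to the caller, and the in-memory H2 url is a placeholder.

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.JdbcDialects

// Url-only constructor introduced in this change: the driver defaults to org.h2.Driver
// and the table name defaults to "".
val options = new JDBCOptions("jdbc:h2:mem:testdb")

// Dialect resolution now takes the full JDBCOptions rather than a bare url string.
val dialect = JdbcDialects.get(options)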
JDBCRDD.scala
@@ -52,9 +52,8 @@ object JDBCRDD extends Logging {
* @throws SQLException if the table contains an unsupported type.
*/
def resolveTable(options: JDBCOptions): StructType = {
-val url = options.url
val table = options.table
-val dialect = JdbcDialects.get(url)
+val dialect = JdbcDialects.get(options)
val conn: Connection = JdbcUtils.createConnectionFactory(options)()
try {
val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
@@ -167,8 +166,7 @@
filters: Array[Filter],
parts: Array[Partition],
options: JDBCOptions): RDD[InternalRow] = {
-val url = options.url
-val dialect = JdbcDialects.get(url)
+val dialect = JdbcDialects.get(options)
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
new JDBCRDD(
sc,
@@ -177,7 +175,7 @@
quotedColumns,
filters,
parts,
-url,
+options.url,
options)
}
}
@@ -217,7 +215,7 @@ private[jdbc] class JDBCRDD(
*/
private val filterWhereClause: String =
filters
-.flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url)))
+.flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(options)))
.map(p => s"($p)").mkString(" AND ")

/**
@@ -284,7 +282,7 @@
val inputMetrics = context.taskMetrics().inputMetrics
val part = thePart.asInstanceOf[JDBCPartition]
conn = getConnection()
-val dialect = JdbcDialects.get(url)
+val dialect = JdbcDialects.get(options)
import scala.collection.JavaConverters._
dialect.beforeFetch(conn, options.asConnectionProperties.asScala.toMap)

JDBCRelation.scala
@@ -114,7 +114,7 @@ private[sql] case class JDBCRelation(

// Check if JDBCRDD.compileFilter can accept input filters
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
-filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
+filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions)).isEmpty)
}

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
JdbcRelationProvider.scala
@@ -61,7 +61,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
if (tableExists) {
mode match {
case SaveMode.Overwrite =>
-if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
+if (options.isTruncate && isCascadingTruncateTable(options) == Some(false)) {
// In this case, we should truncate table and then load.
truncateTable(conn, options.table)
val tableSchema = JdbcUtils.getSchemaOption(conn, options)
JdbcUtils.scala
@@ -66,7 +66,7 @@ object JdbcUtils extends Logging {
* Returns true if the table already exists in the JDBC database.
*/
def tableExists(conn: Connection, options: JDBCOptions): Boolean = {
-val dialect = JdbcDialects.get(options.url)
+val dialect = JdbcDialects.get(options)

// Somewhat hacky, but there isn't a good way to identify whether a table exists for all
// SQL database systems using JDBC meta data calls, considering "table" could also include
@@ -105,8 +105,8 @@
}
}

-def isCascadingTruncateTable(url: String): Option[Boolean] = {
-JdbcDialects.get(url).isCascadingTruncateTable()
+def isCascadingTruncateTable(options: JDBCOptions): Option[Boolean] = {
+JdbcDialects.get(options).isCascadingTruncateTable()
}

/**
@@ -247,7 +247,7 @@
* Returns the schema if the table already exists in the JDBC database.
*/
def getSchemaOption(conn: Connection, options: JDBCOptions): Option[StructType] = {
-val dialect = JdbcDialects.get(options.url)
+val dialect = JdbcDialects.get(options)

try {
val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table))
@@ -702,10 +702,10 @@ object JdbcUtils extends Logging {
*/
def schemaString(
df: DataFrame,
-url: String,
+options: JDBCOptions,
createTableColumnTypes: Option[String] = None): String = {
val sb = new StringBuilder()
-val dialect = JdbcDialects.get(url)
+val dialect = JdbcDialects.get(options)
val userSpecifiedColTypesMap = createTableColumnTypes
.map(parseUserSpecifiedCreateTableColumnTypes(df, _))
.getOrElse(Map.empty[String, String])
@@ -772,9 +772,8 @@
tableSchema: Option[StructType],
isCaseSensitive: Boolean,
options: JDBCOptions): Unit = {
-val url = options.url
val table = options.table
-val dialect = JdbcDialects.get(url)
+val dialect = JdbcDialects.get(options)
val rddSchema = df.schema
val getConnection: () => Connection = createConnectionFactory(options)
val batchSize = options.batchSize
@@ -801,7 +800,7 @@
df: DataFrame,
options: JDBCOptions): Unit = {
val strSchema = schemaString(
-df, options.url, options.createTableColumnTypes)
+df, options, options.createTableColumnTypes)
val table = options.table
val createTableOptions = options.createTableOptions
// Create the table if the table does not exist.
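A minimal sketch (not part of the diff) of calling the refactored schemaString, which now receives the full JDBCOptions so the dialect lookup can see every connection option; spark is assumed to be an active SparkSession, and the writeOptions name, url and table values are placeholders.

import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}

// A toy DataFrame; any DataFrame with a known schema would do.
val df = spark.range(10).toDF("id")

val writeOptions = new JDBCOptions(Map(
  "url" -> "jdbc:postgresql://dbhost:5432/app",   // placeholder connection details
  "dbtable" -> "public.events"))

// Dialect-aware DDL column list, resolved from writeOptions instead of a raw url string.
val ddlColumns = JdbcUtils.schemaString(df, writeOptions)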
AggregatedDialect.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.jdbc

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.types.{DataType, MetadataBuilder}

/**
@@ -30,8 +31,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect

require(dialects.nonEmpty)

-override def canHandle(url : String): Boolean =
-dialects.map(_.canHandle(url)).reduce(_ && _)
+override def canHandle(options: JDBCOptions): Boolean =
+dialects.map(_.canHandle(options)).reduce(_ && _)

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
DB2Dialect.scala
@@ -17,11 +17,12 @@

package org.apache.spark.sql.jdbc

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.types.{BooleanType, DataType, StringType}

private object DB2Dialect extends JdbcDialect {

override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2")
override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:db2")

override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB))
DerbyDialect.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql.jdbc

import java.sql.Types

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.types._


private object DerbyDialect extends JdbcDialect {

override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby")
override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:derby")

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
JdbcDialects.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
import java.sql.Connection

import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.types._

/**
@@ -58,11 +59,11 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
abstract class JdbcDialect extends Serializable {
/**
* Check if this dialect instance can handle a certain jdbc url.
-* @param url the jdbc url.
+* @param options the jdbc options.
* @return True if the dialect can be applied on the given jdbc url.
* @throws NullPointerException if the url is null.
*/
-def canHandle(url : String): Boolean
+def canHandle(options: JDBCOptions): Boolean

/**
* Get the custom datatype mapping for the given jdbc meta information.
@@ -179,8 +180,8 @@ object JdbcDialects {
/**
* Fetch the JdbcDialect class corresponding to a given database url.
*/
-def get(url: String): JdbcDialect = {
-val matchingDialects = dialects.filter(_.canHandle(url))
+def get(options: JDBCOptions): JdbcDialect = {
+val matchingDialects = dialects.filter(_.canHandle(options))
matchingDialects.length match {
case 0 => NoopDialect
case 1 => matchingDialects.head
@@ -193,5 +194,5 @@
* NOOP dialect object, always returning the neutral element.
*/
private object NoopDialect extends JdbcDialect {
-override def canHandle(url : String): Boolean = true
+override def canHandle(options: JDBCOptions): Boolean = true
}
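A minimal sketch (not part of the diff) of a third-party dialect written against the new canHandle(options) contract and registered through the existing JdbcDialects.registerDialect hook; the MyDialect name, the jdbc:mydb url prefix, and the flavor option are hypothetical.

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

object MyDialect extends JdbcDialect {
  // With JDBCOptions in hand, a dialect can branch on any connection option,
  // not only on the url prefix.
  override def canHandle(options: JDBCOptions): Boolean =
    options.url.startsWith("jdbc:mydb") &&
      options.asProperties.getProperty("flavor", "default") == "default"
}

JdbcDialects.registerDialect(MyDialect)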
MsSqlServerDialect.scala
@@ -17,12 +17,13 @@

package org.apache.spark.sql.jdbc

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.types._


private object MsSqlServerDialect extends JdbcDialect {

override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver")
override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:sqlserver")

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
MySQLDialect.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql.jdbc

import java.sql.Types

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.types.{BooleanType, DataType, LongType, MetadataBuilder}

private case object MySQLDialect extends JdbcDialect {

override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:mysql")

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
OracleDialect.scala
@@ -19,17 +19,25 @@ package org.apache.spark.sql.jdbc

import java.sql.Types

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.types._


private case object OracleDialect extends JdbcDialect {

override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle")
private var isAutoConvertNumber2Boolean: Boolean = true

override def canHandle(options: JDBCOptions): Boolean = {
isAutoConvertNumber2Boolean =
options.asProperties.getProperty("autoConvertNumber2Boolean", true.toString).toBoolean
options.url.startsWith("jdbc:oracle")
}

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.NUMERIC) {
-val scale = if (null != md) md.build().getLong("scale") else 0L
+val scale =
+if (null != md && md.build().contains("scale")) md.build().getLong("scale") else 0L
size match {
// Handle NUMBER fields that have no precision/scale in special way
// because JDBC ResultSetMetaData converts this to 0 precision and -127 scale
@@ -43,7 +51,8 @@ private case object OracleDialect extends JdbcDialect {
// Not sure if there is a more robust way to identify the field as a float (or other
// numeric types that do not specify a scale.
case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
-case 1 => Option(BooleanType)
+case 1 if isAutoConvertNumber2Boolean => Option(BooleanType)
+case 1 if !isAutoConvertNumber2Boolean => Option(IntegerType)
case 3 | 5 | 10 => Option(IntegerType)
case 19 if scale == 0L => Option(LongType)
case 19 if scale == 4L => Option(FloatType)
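A minimal sketch (not part of the diff) of how a caller might opt out of the NUMBER(1) to BooleanType mapping via the autoConvertNumber2Boolean option introduced above; spark is assumed to be an active SparkSession, and the url and table values are placeholders.

// With autoConvertNumber2Boolean=false, Oracle NUMBER(1) columns are read as IntegerType
// instead of BooleanType.
val flags = spark.read
  .format("jdbc")
  .option("url", "jdbc:oracle:thin:@//dbhost:1521/service")   // placeholder
  .option("dbtable", "APP.FEATURE_FLAGS")                     // placeholder
  .option("autoConvertNumber2Boolean", "false")
  .load()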
PostgresDialect.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types._

private object PostgresDialect extends JdbcDialect {

override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:postgresql")

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
TeradataDialect.scala
@@ -19,12 +19,15 @@ package org.apache.spark.sql.jdbc

import java.sql.Types

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.types._


private case object TeradataDialect extends JdbcDialect {

override def canHandle(url: String): Boolean = { url.startsWith("jdbc:teradata") }
override def canHandle(options: JDBCOptions): Boolean = {
options.url.startsWith("jdbc:teradata")
}

override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR))