docs/sql-programming-guide.md (4 changes: 1 addition & 3 deletions)

@@ -1895,9 +1895,7 @@ the Data Sources API. The following options are supported:
   <tr>
     <td><code>driver</code></td>
     <td>
-      The class name of the JDBC driver needed to connect to this URL. This class will be loaded
-      on the master and workers before running an JDBC commands to allow the driver to
-      register itself with the JDBC subsystem.
+      The class name of the JDBC driver to use to connect to this URL.
     </td>
   </tr>

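For context, a minimal sketch of how this option is supplied through the Data Sources API, assuming a SQLContext in scope; the URL, table, and driver class below are placeholders, not part of this patch:

```scala
// Hypothetical connection details; substitute your own database and driver.
val jdbcDF = sqlContext.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db.example.com:5432/mydb")
  .option("dbtable", "schema.tablename")
  .option("driver", "org.postgresql.Driver") // the class name this option names
  .load()
```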

@@ -275,7 +275,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
     }
     // connectionProperties should override settings in extraOptions
     props.putAll(connectionProperties)
-    val conn = JdbcUtils.createConnection(url, props)
+    val conn = JdbcUtils.createConnectionFactory(url, props)()

     try {
       var tableExists = JdbcUtils.tableExists(conn, url, table)
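A note on the call-site change above: createConnectionFactory returns a nullary factory (() => Connection) rather than an open connection, so driver-side callers like this one apply it immediately, while executor-side code can ship the factory itself. A rough sketch of the shape, with hypothetical names in scope:

```scala
// Sketch only: the factory is an ordinary Scala closure, so it can be invoked
// eagerly on the driver or captured and shipped to executors.
val factory: () => java.sql.Connection = JdbcUtils.createConnectionFactory(url, props)
val conn: java.sql.Connection = factory() // applied immediately, as in DataFrameWriter
```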

@@ -31,15 +31,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
     val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
-    val driver = parameters.getOrElse("driver", null)
     val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
     val partitionColumn = parameters.getOrElse("partitionColumn", null)
     val lowerBound = parameters.getOrElse("lowerBound", null)
     val upperBound = parameters.getOrElse("upperBound", null)
     val numPartitions = parameters.getOrElse("numPartitions", null)

-    if (driver != null) DriverRegistry.register(driver)
-
     if (partitionColumn != null
         && (lowerBound == null || upperBound == null || numPartitions == null)) {
       sys.error("Partitioning incompletely specified")

@@ -51,10 +51,5 @@ object DriverRegistry extends Logging {
       }
     }
   }
-
-  def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
-    case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
-    case driver => driver.getClass.getCanonicalName
-  }
 }

@@ -17,7 +17,7 @@

 package org.apache.spark.sql.execution.datasources.jdbc

-import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp}
+import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
 import java.util.Properties

 import scala.util.control.NonFatal
@@ -41,7 +41,6 @@ private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition {
   override def index: Int = idx
 }

-
 private[sql] object JDBCRDD extends Logging {

   /**
@@ -120,7 +119,7 @@ private[sql] object JDBCRDD extends Logging {
    */
   def resolveTable(url: String, table: String, properties: Properties): StructType = {
     val dialect = JdbcDialects.get(url)
-    val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)()
+    val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)()
     try {
       val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0")
       try {
@@ -201,36 +200,13 @@ private[sql] object JDBCRDD extends Logging {
     case _ => null
   }

-  /**
-   * Given a driver string and an url, return a function that loads the
-   * specified driver string then returns a connection to the JDBC url.
-   * getConnector is run on the driver code, while the function it returns
-   * is run on the executor.
-   *
-   * @param driver - The class name of the JDBC driver for the given url, or null if the class name
-   *                 is not necessary.
-   * @param url - The JDBC url to connect to.
-   *
-   * @return A function that loads the driver and connects to the url.
-   */
-  def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
-    () => {
-      try {
-        if (driver != null) DriverRegistry.register(driver)
-      } catch {
-        case e: ClassNotFoundException =>
-          logWarning(s"Couldn't find class $driver", e)
-      }
-      DriverManager.getConnection(url, properties)
-    }
-  }
-
   /**
    * Build and return JDBCRDD from the given information.
    *
    * @param sc - Your SparkContext.
    * @param schema - The Catalyst schema of the underlying database table.
-   * @param driver - The class name of the JDBC driver for the given url.
    * @param url - The JDBC url to connect to.
    * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use.
    * @param requiredColumns - The names of the columns to SELECT.
@@ -243,7 +219,6 @@ private[sql] object JDBCRDD extends Logging {
   def scanTable(
       sc: SparkContext,
       schema: StructType,
-      driver: String,
       url: String,
       properties: Properties,
       fqTable: String,
@@ -254,7 +229,7 @@ private[sql] object JDBCRDD extends Logging {
     val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
     new JDBCRDD(
       sc,
-      getConnector(driver, url, properties),
+      JdbcUtils.createConnectionFactory(url, properties),
       pruneSchema(schema, requiredColumns),
       fqTable,
       quotedColumns,
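The removed getConnector scaladoc made a point worth keeping in mind: the factory builder runs in driver code, while the function it returns runs on executors. A self-contained toy illustration of that split, under the assumption that an H2 (or any other JDBC) driver is on the classpath; the object name and URL are hypothetical:

```scala
import java.sql.{Connection, DriverManager}

// Toy illustration of the driver/executor split: build the factory once,
// then invoke the returned closure wherever it ends up running.
object ConnectionFactoryDemo {
  def createConnectionFactory(url: String): () => Connection =
    () => DriverManager.getConnection(url) // the body runs wherever the closure is invoked

  def main(args: Array[String]): Unit = {
    val factory = createConnectionFactory("jdbc:h2:mem:demo") // hypothetical in-memory URL
    val conn = factory()
    try println(conn.getMetaData.getURL) finally conn.close()
  }
}
```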

@@ -91,12 +91,10 @@ private[sql] case class JDBCRelation(
   override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
-    val driver: String = DriverRegistry.getDriverClassName(url)
     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sqlContext.sparkContext,
       schema,
-      driver,
       url,
       properties,
       table,

@@ -17,9 +17,10 @@

 package org.apache.spark.sql.execution.datasources.jdbc

-import java.sql.{Connection, PreparedStatement}
+import java.sql.{Connection, Driver, DriverManager, PreparedStatement}
 import java.util.Properties

+import scala.collection.JavaConverters._
 import scala.util.Try
 import scala.util.control.NonFatal

@@ -34,10 +35,31 @@ import org.apache.spark.sql.{DataFrame, Row}
 object JdbcUtils extends Logging {

   /**
-   * Establishes a JDBC connection.
+   * Returns a factory for creating connections to the given JDBC URL.
+   *
+   * @param url the JDBC url to connect to.
+   * @param properties JDBC connection properties.
    */
-  def createConnection(url: String, connectionProperties: Properties): Connection = {
-    JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)()
+  def createConnectionFactory(url: String, properties: Properties): () => Connection = {
+    val userSpecifiedDriverClass = Option(properties.getProperty("driver"))
+    userSpecifiedDriverClass.foreach(DriverRegistry.register)
+    // Performing this part of the logic on the driver guards against the corner-case where the
+    // driver returned for a URL is different on the driver and executors due to classpath
+    // differences.
+    val driverClass: String = userSpecifiedDriverClass.getOrElse {
+      DriverManager.getDriver(url).getClass.getCanonicalName
+    }
+    () => {
+      userSpecifiedDriverClass.foreach(DriverRegistry.register)

[Inline review thread]
Reviewer: This is the one that registers the right driver on the executor side, right?
Author: Yep, that's right: this function gets shipped to executors, where it's called to create connections.

+      val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
+        case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d

[Inline review comment]
Author: This is the only real bit of trickiness here, and it was the part that was missing from my spark-redshift patch.

+        case d if d.getClass.getCanonicalName == driverClass => d
+      }.getOrElse {
+        throw new IllegalStateException(
+          s"Did not find registered driver with class $driverClass")
+      }
+      driver.connect(url, properties)
+    }
   }

   /**
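A note on the DriverWrapper case in the lookup above: Spark registers user-supplied drivers behind a wrapper because java.sql.DriverManager is picky about which classloader registered a driver, so the lookup has to compare against the wrapped driver's class name. A simplified sketch of that wrapper idea (the real class lives alongside DriverRegistry in Spark; this is a paraphrase, not part of the diff):

```scala
import java.sql.{Connection, Driver, DriverPropertyInfo}
import java.util.Properties
import java.util.logging.Logger

// Simplified sketch: delegate every Driver method to the wrapped instance.
// Because the wrapper itself is loaded by a classloader DriverManager trusts,
// drivers from user-added jars become usable through it.
class DriverWrapper(val wrapped: Driver) extends Driver {
  override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url)
  override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info)
  override def getMajorVersion: Int = wrapped.getMajorVersion
  override def getMinorVersion: Int = wrapped.getMinorVersion
  override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] =
    wrapped.getPropertyInfo(url, info)
  override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant()
  override def getParentLogger: Logger = wrapped.getParentLogger
}
```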
@@ -242,15 +264,14 @@ object JdbcUtils extends Logging {
       df: DataFrame,
       url: String,
       table: String,
-      properties: Properties = new Properties()) {
+      properties: Properties) {
     val dialect = JdbcDialects.get(url)
     val nullTypes: Array[Int] = df.schema.fields.map { field =>
       getJdbcType(field.dataType, dialect).jdbcNullType
     }

     val rddSchema = df.schema
-    val driver: String = DriverRegistry.getDriverClassName(url)
-    val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
+    val getConnection: () => Connection = createConnectionFactory(url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
       savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
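Since saveTable no longer defaults its properties argument, callers pass connection properties explicitly. A minimal usage sketch through the public write path, with placeholder credentials and URL:

```scala
import java.util.Properties

// Placeholder connection details; the "driver" entry is optional and is the
// user-specified class that createConnectionFactory registers up front.
val connectionProperties = new Properties()
connectionProperties.put("user", "username")
connectionProperties.put("password", "password")
connectionProperties.put("driver", "org.postgresql.Driver")

df.write.jdbc("jdbc:postgresql://db.example.com:5432/mydb", "schema.tablename", connectionProperties)
```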