Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ package org.apache.spark.sql

import java.util.Properties

import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.plans.logical.{Project, InsertIntoTable}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCRelation, JdbcUtils}
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource}
import org.apache.spark.sql.sources.HadoopFsRelation

import scala.collection.JavaConverters._


/**
* :: Experimental ::
Expand Down Expand Up @@ -269,12 +269,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
val props = new Properties()
extraOptions.foreach { case (key, value) =>
props.put(key, value)
}
// connectionProperties should override settings in extraOptions
props.putAll(connectionProperties)
val props = JDBCRelation.getEffectiveProperties(connectionProperties, this.extraOptions)
val conn = JdbcUtils.createConnection(url, props)

try {
Expand Down Expand Up @@ -303,7 +298,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
conn.close()
}

JdbcUtils.saveTable(df, url, table, props)
val propsForSave = JDBCRelation.getEffectiveProperties(connectionProperties, this.extraOptions)
JdbcUtils.saveTable(df, url, table, propsForSave)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,3 @@ object DriverRegistry extends Logging {
case driver => driver.getClass.getCanonicalName
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ private[sql] object JDBCRDD extends Logging {
case e: ClassNotFoundException =>
logWarning(s"Couldn't find class $driver", e)
}
DriverManager.getConnection(url, properties)
DriverManager.getConnection(url, JDBCRelation.getEffectiveProperties(properties))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ private[sql] object JDBCRelation {
}
ans.toArray
}

def getEffectiveProperties(
connectionProperties: Properties,
extraOptions: scala.collection.Map[String, String] = Map()): Properties = {
val props = new Properties()
extraOptions.foreach { case (key, value) =>
props.put(key, value)
}
// connectionProperties should override settings in extraOptions
props.putAll(connectionProperties)
props
}
}

private[sql] case class JDBCRelation(
Expand All @@ -88,7 +100,11 @@ private[sql] case class JDBCRelation(

override val needConversion: Boolean = false

override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
override val schema: StructType = JDBCRDD.resolveTable(
url,
table,
JDBCRelation.getEffectiveProperties(properties)
)

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
val driver: String = DriverRegistry.getDriverClassName(url)
Expand All @@ -98,7 +114,7 @@ private[sql] case class JDBCRelation(
schema,
driver,
url,
properties,
JDBCRelation.getEffectiveProperties(properties),
table,
requiredColumns,
filters,
Expand All @@ -108,6 +124,6 @@ private[sql] case class JDBCRelation(
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
.jdbc(url, table, properties)
.jdbc(url, table, JDBCRelation.getEffectiveProperties(properties))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ import java.math.BigDecimal
import java.sql.DriverManager
import java.util.{Calendar, GregorianCalendar, Properties}

import org.h2.jdbc.JdbcSQLException
import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.h2.jdbc.JdbcSQLException
import org.scalatest.BeforeAndAfter

class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
import testImplicits._
Expand Down Expand Up @@ -469,6 +468,13 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
assert(derbyDialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "BOOLEAN")
}

test("Basic API with Unserializable Driver Properties") {
UnserializableDriverHelper.replaceDriverDuring {
assert(sqlContext.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3)
}
}

test("table exists query by jdbc dialect") {
val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ package org.apache.spark.sql.jdbc
import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.util.Utils
import org.scalatest.BeforeAndAfter

class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {

Expand Down Expand Up @@ -151,4 +150,12 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}

test("INSERT to JDBC Datasource with Unserializable Driver Properties") {
UnserializableDriverHelper.replaceDriverDuring {
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.jdbc

import java.sql.{DriverManager, Connection}
import java.util.Properties
import java.util.logging.Logger

object UnserializableDriverHelper {

def replaceDriverDuring[T](f: => T): T = {
import scala.collection.JavaConverters._

object UnserializableH2Driver extends org.h2.Driver {

override def connect(url: String, info: Properties): Connection = {

val result = super.connect(url, info)
info.put("unserializableDriver", this)
result
}

override def getParentLogger: Logger = null
}

val oldDrivers = DriverManager.getDrivers.asScala.filter(_.acceptsURL("jdbc:h2:")).toSeq
oldDrivers.foreach{
DriverManager.deregisterDriver
}
DriverManager.registerDriver(UnserializableH2Driver)

val result = try {
f
}
finally {
DriverManager.deregisterDriver(UnserializableH2Driver)
oldDrivers.foreach{
DriverManager.registerDriver
}
}
result
}
}