diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 11613dd912eca..b91ea8b974eb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -38,6 +38,9 @@ private[sql] case class JDBCPartitioningInfo(
     numPartitions: Int)
 
 private[sql] object JDBCRelation extends Logging {
+
+  import scala.collection.JavaConverters._
+
   /**
    * Given a partitioning schematic (a column of integral type, a number of
    * partitions, and upper and lower bounds on the column's value), generate
@@ -99,6 +102,16 @@ private[sql] object JDBCRelation extends Logging {
     }
     ans.toArray
   }
+
+  def getEffectiveProperties(
+      connectionProperties: Properties,
+      extraOptions: scala.collection.Map[String, String] = Map()): Properties = {
+    val props = new Properties()
+    props.putAll(extraOptions.asJava)
+    // connectionProperties should override settings in extraOptions
+    props.putAll(connectionProperties)
+    props
+  }
 }
 
 private[sql] case class JDBCRelation(
@@ -127,7 +140,7 @@ private[sql] case class JDBCRelation(
       sparkSession.sparkContext,
       schema,
       url,
-      properties,
+      JDBCRelation.getEffectiveProperties(properties),
       table,
       requiredColumns,
       filters,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index cb474cbd0ae7e..91f162aafff5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -63,7 +63,7 @@ object JdbcUtils extends Logging {
         throw new IllegalStateException(
           s"Did not find registered driver with class $driverClass")
       }
-      driver.connect(url, properties)
+      driver.connect(url, JDBCRelation.getEffectiveProperties(properties))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 995b1200a2294..173b83c2dd8f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -686,6 +686,13 @@ class JDBCSuite extends SparkFunSuite
       Some(DecimalType(DecimalType.MAX_PRECISION, 10)))
   }
 
+  test("SPARK-10625: JDBC read should allow driver to insert unserializable into properties") {
+    UnserializableDriverHelper.replaceDriverDuring {
+      assert(sqlContext.read.jdbc(
+        urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3)
+    }
+  }
+
   test("table exists query by jdbc dialect") {
     val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
     val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index d99b3cf975f4f..a9c643abcc909 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -196,4 +196,12 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count())
     assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
   }
+
+  test("SPARK-10625: JDBC write should allow driver to insert unserializable into properties") {
+    UnserializableDriverHelper.replaceDriverDuring {
+      sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+      assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+      assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/UnserializableDriverHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/UnserializableDriverHelper.scala
new file mode 100644
index 0000000000000..ef7a305ff8424
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/UnserializableDriverHelper.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.{Connection, DriverManager}
+import java.util.Properties
+import java.util.logging.Logger
+
+object UnserializableDriverHelper {
+
+  import scala.collection.JavaConverters._
+
+  def replaceDriverDuring[T](f: => T): T = {
+    object UnserializableH2Driver extends org.h2.Driver {
+
+      override def connect(url: String, info: Properties): Connection = {
+
+        val result = super.connect(url, info)
+        info.put("unserializableDriver", this)
+        result
+      }
+
+      override def getParentLogger: Logger = null
+    }
+
+    val oldDrivers = DriverManager.getDrivers.asScala.toList.filter(_.acceptsURL("jdbc:h2:"))
+    oldDrivers.foreach(DriverManager.deregisterDriver)
+    DriverManager.registerDriver(UnserializableH2Driver)
+
+    val result = try {
+      f
+    } finally {
+      DriverManager.deregisterDriver(UnserializableH2Driver)
+      oldDrivers.foreach(DriverManager.registerDriver)
+    }
+    result
+  }
+}
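
The heart of the patch is the copy-then-merge behaviour of getEffectiveProperties: it builds a fresh java.util.Properties from extraOptions and then overlays connectionProperties, so anything a JDBC driver later stuffs into the Properties handed to connect() (as UnserializableH2Driver does in the new test helper) lands in a throwaway copy rather than in the caller-supplied Properties that Spark serializes into task closures. Below is a minimal standalone sketch of that behaviour; the property names and values are illustrative only and not part of the patch.

import java.util.Properties
import scala.collection.JavaConverters._

// Same merge logic as JDBCRelation.getEffectiveProperties in the patch above.
def getEffectiveProperties(
    connectionProperties: Properties,
    extraOptions: scala.collection.Map[String, String] = Map()): Properties = {
  val props = new Properties()
  props.putAll(extraOptions.asJava)    // defaults taken from extra options
  props.putAll(connectionProperties)   // connection properties win on conflicts
  props
}

// Illustrative usage (hypothetical values):
val user = new Properties()
user.setProperty("user", "alice")

val effective = getEffectiveProperties(user, Map("user" -> "bob", "fetchsize" -> "100"))
assert(effective.getProperty("user") == "alice")      // connectionProperties override
assert(effective.getProperty("fetchsize") == "100")   // extra option carried over
// A driver may now mutate `effective` freely, e.g. insert an unserializable object
// the way UnserializableH2Driver does; the original `user` object is untouched and
// remains safe to serialize.
assert(user.getProperty("fetchsize") == null)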