From e7aa6e7e5618e2a1f5157919322aefd3511caa00 Mon Sep 17 00:00:00 2001
From: tribbloid
Date: Tue, 15 Sep 2015 19:02:57 -0400
Subject: [PATCH] test case demonstrating SPARK-10625: Spark SQL JDBC
 read/write is unable to handle JDBC Drivers that adds unserializable objects
 into connection properties

add one more unit test

fix JDBCRelation & DataFrameWriter to pass all tests

revise scala style

put driver replacement code into a shared function

fix styling

upgrade to master and resolve all related issues
---
 .../apache/spark/sql/DataFrameWriter.scala    | 20 +++----
 .../datasources/jdbc/DriverRegistry.scala     |  1 -
 .../execution/datasources/jdbc/JDBCRDD.scala  |  2 +-
 .../datasources/jdbc/JDBCRelation.scala       | 22 ++++++-
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 12 +++-
 .../spark/sql/jdbc/JDBCWriteSuite.scala       | 13 ++++-
 .../sql/jdbc/UnserializableDriverHelper.scala | 58 +++++++++++++++++++
 7 files changed, 105 insertions(+), 23 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/jdbc/UnserializableDriverHelper.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 03867beb78224..a7824e71e38f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -19,16 +19,16 @@ package org.apache.spark.sql
 
 import java.util.Properties
 
-import scala.collection.JavaConverters._
-
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
 import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
-import org.apache.spark.sql.catalyst.plans.logical.{Project, InsertIntoTable}
-import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCRelation, JdbcUtils}
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource}
 import org.apache.spark.sql.sources.HadoopFsRelation
 
+import scala.collection.JavaConverters._
+
 
 /**
  * :: Experimental ::
@@ -269,12 +269,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
-    val props = new Properties()
-    extraOptions.foreach { case (key, value) =>
-      props.put(key, value)
-    }
-    // connectionProperties should override settings in extraOptions
-    props.putAll(connectionProperties)
+    val props = JDBCRelation.getEffectiveProperties(connectionProperties, this.extraOptions)
 
     val conn = JdbcUtils.createConnection(url, props)
     try {
@@ -303,7 +298,8 @@
       conn.close()
     }
 
-    JdbcUtils.saveTable(df, url, table, props)
+    val propsForSave = JDBCRelation.getEffectiveProperties(connectionProperties, this.extraOptions)
+    JdbcUtils.saveTable(df, url, table, propsForSave)
   }
 
   /**
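The inline merge loop removed from DataFrameWriter.jdbc() above reappears as the shared helper JDBCRelation.getEffectiveProperties (defined in a later file). A standalone sketch of the merge semantics it implements, runnable without Spark; the object name is made up for illustration:

import java.util.Properties

// Mirrors the helper's contract: extraOptions are copied first, then
// connectionProperties are layered on top and win on key collisions.
object EffectivePropertiesSketch {
  def getEffectiveProperties(
      connectionProperties: Properties,
      extraOptions: Map[String, String] = Map()): Properties = {
    val props = new Properties()
    extraOptions.foreach { case (key, value) => props.put(key, value) }
    props.putAll(connectionProperties)
    props
  }

  def main(args: Array[String]): Unit = {
    val connectionProperties = new Properties()
    connectionProperties.put("user", "alice")
    val merged = getEffectiveProperties(
      connectionProperties, Map("user" -> "bob", "fetchsize" -> "100"))
    assert(merged.getProperty("user") == "alice") // connectionProperties win
    assert(merged.getProperty("fetchsize") == "100") // extraOptions preserved
  }
}

Note also that jdbc() now builds a fresh copy on every call, so the Properties instance the caller passes in is never handed directly to the JDBC driver.
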
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
index 7ccd61ed469e9..72d155cad846f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
@@ -57,4 +57,3 @@ object DriverRegistry extends Logging {
     case driver => driver.getClass.getCanonicalName
   }
 }
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 57a8a044a37cd..f5f7094241e5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -187,7 +187,7 @@ private[sql] object JDBCRDD extends Logging {
        case e: ClassNotFoundException =>
          logWarning(s"Couldn't find class $driver", e)
      }
-     DriverManager.getConnection(url, properties)
+     DriverManager.getConnection(url, JDBCRelation.getEffectiveProperties(properties))
    }
  }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index f9300dc2cb529..195a5e5a97557 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -75,6 +75,18 @@ private[sql] object JDBCRelation {
     }
     ans.toArray
   }
+
+  def getEffectiveProperties(
+      connectionProperties: Properties,
+      extraOptions: scala.collection.Map[String, String] = Map()): Properties = {
+    val props = new Properties()
+    extraOptions.foreach { case (key, value) =>
+      props.put(key, value)
+    }
+    // connectionProperties should override settings in extraOptions
+    props.putAll(connectionProperties)
+    props
+  }
 }
 
 private[sql] case class JDBCRelation(
@@ -88,7 +100,11 @@ private[sql] case class JDBCRelation(
 
   override val needConversion: Boolean = false
 
-  override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
+  override val schema: StructType = JDBCRDD.resolveTable(
+    url,
+    table,
+    JDBCRelation.getEffectiveProperties(properties)
+  )
 
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
     val driver: String = DriverRegistry.getDriverClassName(url)
@@ -98,7 +114,7 @@ private[sql] case class JDBCRelation(
       schema,
       driver,
       url,
-      properties,
+      JDBCRelation.getEffectiveProperties(properties),
       table,
       requiredColumns,
       filters,
@@ -108,6 +124,6 @@ private[sql] case class JDBCRelation(
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
     data.write
       .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
-      .jdbc(url, table, properties)
+      .jdbc(url, table, JDBCRelation.getEffectiveProperties(properties))
   }
 }
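This file is the heart of the fix: JDBCRelation now funnels every connection-property use through getEffectiveProperties, so the JDBC driver only ever sees a fresh copy. The motivation, per SPARK-10625, is that a driver may mutate the Properties handed to connect(), and once a non-Serializable value lands in the instance captured by an RDD closure, task serialization fails. A self-contained sketch of that failure mode, no Spark required; all names are illustrative:

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}
import java.util.Properties

object PropertiesPollutionSketch {
  // Deliberately not java.io.Serializable, like a live driver object.
  class NotSerializableValue

  def serialize(o: AnyRef): Unit = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try out.writeObject(o) finally out.close()
  }

  def main(args: Array[String]): Unit = {
    val original = new Properties()
    original.put("user", "alice")
    serialize(original) // fine: only String keys and values

    // What the fix does: let the driver see only a fresh copy.
    val forDriver = new Properties()
    forDriver.putAll(original)
    // What a misbehaving driver does inside connect():
    forDriver.put("unserializableDriver", new NotSerializableValue)

    serialize(original) // still fine: the pollution stayed in the copy
    try {
      serialize(forDriver) // throws: Hashtable serializes its values
    } catch {
      case _: NotSerializableException =>
        println("polluted copy fails to serialize, as expected")
    }
  }
}
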
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index d530b1a469ce2..cb9f4b67b6ce5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -21,13 +21,12 @@ import java.math.BigDecimal
 import java.sql.DriverManager
 import java.util.{Calendar, GregorianCalendar, Properties}
 
-import org.h2.jdbc.JdbcSQLException
-import org.scalatest.BeforeAndAfter
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
+import org.h2.jdbc.JdbcSQLException
+import org.scalatest.BeforeAndAfter
 
 class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
   import testImplicits._
@@ -469,6 +468,13 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
     assert(derbyDialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "BOOLEAN")
   }
 
+  test("Basic API with Unserializable Driver Properties") {
+    UnserializableDriverHelper.replaceDriverDuring {
+      assert(sqlContext.read.jdbc(
+        urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3)
+    }
+  }
+
   test("table exists query by jdbc dialect") {
     val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
     val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e23ee6693133b..f3ea22d81f408 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -20,12 +20,11 @@ package org.apache.spark.sql.jdbc
 import java.sql.DriverManager
 import java.util.Properties
 
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.sql.{Row, SaveMode}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.{Row, SaveMode}
 import org.apache.spark.util.Utils
+import org.scalatest.BeforeAndAfter
 
 class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
 
@@ -151,4 +150,12 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
     assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
   }
+
+  test("INSERT to JDBC Datasource with Unserializable Driver Properties") {
+    UnserializableDriverHelper.replaceDriverDuring {
+      sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+      assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+      assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+    }
+  }
 }
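For orientation, the two suites above exercise the public entry points the fix touches. A hypothetical round trip through them; the url, table names, and enclosing method are placeholders, not part of the patch:

import java.util.Properties
import org.apache.spark.sql.{SQLContext, SaveMode}

// Read path: JDBCRelation resolves the schema and JDBCRDD opens
// per-partition connections; write path: DataFrameWriter.jdbc merges
// connection properties through the shared helper before connecting.
def roundTrip(sqlContext: SQLContext): Unit = {
  val props = new Properties()
  props.setProperty("user", "sa")

  val df = sqlContext.read.jdbc("jdbc:h2:mem:testdb", "TEST.PEOPLE", props)
  df.write
    .mode(SaveMode.Append)
    .jdbc("jdbc:h2:mem:testdb", "TEST.PEOPLE1", props)
}
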
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.{DriverManager, Connection}
+import java.util.Properties
+import java.util.logging.Logger
+
+object UnserializableDriverHelper {
+
+  def replaceDriverDuring[T](f: => T): T = {
+    import scala.collection.JavaConverters._
+
+    object UnserializableH2Driver extends org.h2.Driver {
+
+      override def connect(url: String, info: Properties): Connection = {
+
+        val result = super.connect(url, info)
+        info.put("unserializableDriver", this)
+        result
+      }
+
+      override def getParentLogger: Logger = null
+    }
+
+    val oldDrivers = DriverManager.getDrivers.asScala.filter(_.acceptsURL("jdbc:h2:")).toSeq
+    oldDrivers.foreach{
+      DriverManager.deregisterDriver
+    }
+    DriverManager.registerDriver(UnserializableH2Driver)
+
+    val result = try {
+      f
+    }
+    finally {
+      DriverManager.deregisterDriver(UnserializableH2Driver)
+      oldDrivers.foreach{
+        DriverManager.registerDriver
+      }
+    }
+    result
+  }
+}
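The helper is a loan pattern: it deregisters the real H2 driver, installs a stand-in whose connect() pollutes the caller's Properties with a non-serializable value, runs the block, and restores the original drivers in the finally clause even when the block throws. A hedged usage sketch outside the suites above; the enclosing method, url, and table name are assumptions for illustration:

import java.util.Properties
import org.apache.spark.sql.SQLContext

// Hypothetical test body: every connection opened inside the block goes
// through the polluting driver, so the read only succeeds if Spark keeps
// its own Properties copy out of the driver's reach.
def checkReadSurvivesPollution(sqlContext: SQLContext, url: String): Unit = {
  UnserializableDriverHelper.replaceDriverDuring {
    val rows = sqlContext.read.jdbc(url, "TEST.PEOPLE", new Properties).collect()
    assert(rows.length > 0)
  }
}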