diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b9be7a7545ef8..dde2a9b708c74 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1086,6 +1086,13 @@ the following case-sensitive options: + + maxConnections + + The maximum number of concurrent JDBC connections that can be used, if set. Only applies when writing. It works by limiting the operation's parallelism, which depends on the input's partition count. If its partition count exceeds this limit, the operation will coalesce the input to fewer partitions before writing. + + + isolationLevel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 7f419b5788c4f..d416eec6ddaec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -122,6 +122,11 @@ class JDBCOptions( case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE } + // the maximum number of connections + val maxConnections = parameters.get(JDBC_MAX_CONNECTIONS).map(_.toInt) + require(maxConnections.isEmpty || maxConnections.get > 0, + s"Invalid value `${maxConnections.get}` for parameter `$JDBC_MAX_CONNECTIONS`. " + + "The minimum value is 1.") } object JDBCOptions { @@ -144,4 +149,5 @@ object JDBCOptions { val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") + val JDBC_MAX_CONNECTIONS = newOption("maxConnections") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 41edb6511c2ce..cdc3c99daa1ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -667,7 +667,14 @@ object JdbcUtils extends Logging { val getConnection: () => Connection = createConnectionFactory(options) val batchSize = options.batchSize val isolationLevel = options.isolationLevel - df.foreachPartition(iterator => savePartition( + val maxConnections = options.maxConnections + val repartitionedDF = + if (maxConnections.isDefined && maxConnections.get < df.rdd.getNumPartitions) { + df.coalesce(maxConnections.get) + } else { + df + } + repartitionedDF.foreachPartition(iterator => savePartition( getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e3d3c6c3a887c..5795b4d860cb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -312,4 +312,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { .options(properties.asScala) .save() } + + test("SPARK-18413: Add `maxConnections` JDBCOption") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val e = intercept[IllegalArgumentException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .option("url", url1) + .option(s"${JDBCOptions.JDBC_MAX_CONNECTIONS}", "0") + .save() + }.getMessage + assert(e.contains("Invalid value `0` for parameter `maxConnections`. The minimum value is 1")) + } }