[SPARK-19335] Introduce insert, update, and upsert commands to the JdbcUtils class #16685
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,8 @@ | |
|
|
||
| package org.apache.spark.sql.execution.datasources.jdbc | ||
|
|
||
| import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} | ||
| import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException, Statement} | ||
| import java.util.Properties | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.util.Try | ||
|
|
@@ -26,11 +27,12 @@ import scala.util.control.NonFatal | |
| import org.apache.spark.TaskContext | ||
| import org.apache.spark.executor.InputMetrics | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.sql.{AnalysisException, DataFrame, Row} | ||
| import org.apache.spark.sql._ | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.encoders.RowEncoder | ||
| import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow | ||
| import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.unsafe.types.UTF8String | ||
|
|
@@ -721,15 +723,249 @@ object JdbcUtils extends Logging { | |
| ) | ||
| } | ||
|
|
||
| /** | ||
| * Check whether a table exists in a given database | ||
| * | ||
| * @return True if the table exists. | ||
| */ | ||
| @transient | ||
| def checkTableExists(targetDb: String, tableName: String): Boolean = { | ||
| val dbc: Connection = DriverManager.getConnection(targetDb) | ||
| val dbm = dbc.getMetaData() | ||
| // Look the table up in the database metadata. | ||
| // The caller decides how to write based on whether it exists. | ||
| val tables = dbm.getTables(null, null, tableName, null) | ||
| val exists = tables.next() // next() returns false when no matching table is found | ||
| dbc.close() | ||
| exists | ||
| } | ||
|
|
||
| // Provide a reasonable starting batch size for database operations. | ||
| private val DEFAULT_BATCH_SIZE: Int = 200 | ||
|
|
||
| // Limit the number of database connections. Some DBs suffer when there are many open | ||
| // connections. | ||
| private val DEFAULT_MAX_CONNECTIONS: Int = 50 | ||
|
Member
Well, since Spark 2.1, we already provide the param for limiting the maximum number of concurrent JDBC connections when inserting data into JDBC tables. The param is
Author
A large DataFrame may have hundreds or thousands of partitions. Having a separate value not tied to partitions makes it easier to work with databases that allow fewer connections, or where performance degrades when hundreds or thousands of connections are open, e.g. Postgres. If you're suggesting simply changing the default to numPartitions, I think that's wholly valid, but I'm not aware of how to access that value in a static context.
Member
Please see the logic here. We already can do it for Insert by using
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it, thanks! |
||
|
|
||
| private val DEFAULT_ID_COLUMN: String = "id" | ||
|
|
||
| /** | ||
| * Given a DataFrame of SQL statements to execute, connect to a database and perform a | ||
| * sequence of batched operations. | ||
| * Uses a maximum number of simultaneous connections as defined by "maxConnections". | ||
| * | ||
| * @param targetDb The database to connect to, provided as a jdbc URL, e.g. | ||
| * "jdbc:postgresql://192.168.0.1:5432/MY_DB_NAME?user=MY_USER&password=PASSWORD" | ||
| * @param statements A DataFrame of SQL statements to execute (as strings) | ||
| * @param batchSize The batch size to use, default 200 | ||
| * @param maxConnections Maximum number of simultaneous connections to open to the database | ||
| */ | ||
| private def executeStatements(targetDb: String, | ||
| statements: DataFrame, | ||
| batchSize: Int = DEFAULT_BATCH_SIZE, | ||
| maxConnections: Int = DEFAULT_MAX_CONNECTIONS): Unit = { | ||
| import statements.sparkSession.implicits._ | ||
|
|
||
| // To avoid overloading the database, coalesce to a set number of partitions if necessary | ||
| val coalesced = if (statements.rdd.getNumPartitions > maxConnections) { | ||
| statements.coalesce(maxConnections) | ||
| } else { | ||
| statements | ||
| } | ||
|
|
||
| coalesced.foreachPartition { batch: Iterator[Row] => | ||
| val dbc: Connection = DriverManager.getConnection(targetDb) | ||
| dbc.setAutoCommit(false) | ||
|
|
||
| val st: Statement = dbc.createStatement() | ||
|
|
||
| try { | ||
| batch.grouped(batchSize).foreach { rowBatch => | ||
| rowBatch.foreach { statement => | ||
| st.addBatch(statement.getString(0)) | ||
| } | ||
| st.executeBatch() | ||
| dbc.commit() | ||
| } | ||
| } | ||
| finally { | ||
| dbc.close() | ||
| } | ||
| } | ||
| } | ||
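To make the batching concrete: with the defaults, a partition holding 1,000 statements executes as five JDBC batches of 200 statements, each followed by a commit. The grouping idiom in isolation, as plain Scala with no database needed:

```scala
// Iterator.grouped splits the statement stream into fixed-size chunks;
// each chunk corresponds to one addBatch/executeBatch/commit round trip.
val statements: Iterator[String] = (1 to 1000).iterator.map(i => s"-- statement $i")
statements.grouped(200).foreach { batch =>
  println(s"would execute ${batch.size} statements as one JDBC batch")
}
```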
|
|
||
| /** | ||
| * Given a row of a database and specified key, generate an insert query | ||
| * | ||
| * @param row A row of a dataframe | ||
| * @param schema The schema for the dataframe | ||
| * @param tableName The name of the table to update | ||
| * @return A SQL statement that inserts the referenced row into the table. | ||
| */ | ||
| def genInsertScript(row: Row, schema: StructType, tableName: String): String = { | ||
| val schemaString = schema.map(s => s.name).reduce(_ + ", " + _) | ||
|
|
||
| val valString = { | ||
| row.toSeq.map(v => "'" + v.toString.replaceAll("'", "''") + "'").reduce(_ + "," + _) | ||
| } | ||
|
|
||
| val insert = s"INSERT INTO $tableName ($schemaString) VALUES ($valString);" | ||
| insert | ||
| } | ||
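As an illustration of the generated SQL (the `people` table and its two-column schema are hypothetical), every value, including the integer, is rendered as a quoted literal, with embedded single quotes doubled:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType)))

// Returns: INSERT INTO people (id, name) VALUES ('1','O''Brien');
val sql = JdbcUtils.genInsertScript(Row(1, "O'Brien"), schema, "people")
```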
|
|
||
| /** | ||
| * Given a row of a database and specified key, generate an update query | ||
| * | ||
| * @param row A row of a dataframe | ||
| * @param schema The schema for the dataframe | ||
| * @param tableName The name of the table to update | ||
| * @param keyField The name of the column that can serve as a primary key | ||
| * @return A SQL statement that will update the referenced row in the table. | ||
| */ | ||
| def genUpdateScript(row: Row, schema: StructType, tableName: String, keyField: String): String = { | ||
| val keyVal = row.get(schema.fieldIndex(keyField)) | ||
| val zipped = row.toSeq.zip(schema.map(s => s.name)) | ||
|
|
||
| val valString = zipped.flatMap(s => { | ||
| // Value fields are bounded with single quotes so escape any single quotes in the value | ||
| val noQuotes: String = s._1.toString.replaceAll("'", "''") | ||
| if (!s._2.equals(keyField)) Seq(s"${s._2}='$noQuotes'") else Seq() | ||
| }).reduce(_ + ",\n" + _) | ||
|
|
||
| val update = s"UPDATE $tableName set \n $valString \n WHERE $keyField = $keyVal;" | ||
| update | ||
| } | ||
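Reusing the hypothetical schema from the previous sketch, the generated statement excludes the key column from the SET clause and uses it only in the WHERE clause:

```scala
// Returns (whitespace compacted):
//   UPDATE people set name='O''Brien' WHERE id = 1;
val update = JdbcUtils.genUpdateScript(Row(1, "O'Brien"), schema, "people", "id")
```

Note that the key value is interpolated without quotes, so this sketch assumes a numeric key column; a string-typed key would produce invalid SQL as written.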
|
|
||
| /** | ||
| * Write a given DataFrame into a provided database via JDBC by performing an update command. | ||
| * If the database contains a row with a matching id, then all other columns will be updated. | ||
| * If the database does not contain a row with a matching id an insert is performed instead. | ||
| * Note: This command expects that the database contain an indexed column to use for the update, | ||
| * default is "id". | ||
| * TODO: Add support for arbitrary length keys | ||
| * | ||
| * @param df The dataframe to write to the database. | ||
| * @param targetDb The database to update provided as a jdbc URL, e.g. | ||
| * "jdbc:postgresql://192.168.0.1:5432/MY_DB_NAME?user=MY_USER&password=PASSWORD" | ||
| * @param tableName The name of the table to update | ||
| * @param batchSize The batch size to use, default 200 | ||
| * @param maxConnections Maximum number of simultaneous connections to open to the database | ||
| * @param idColumn The column to use as the primary key for resolving conflicts, | ||
| * default is "id". | ||
| */ | ||
| @transient | ||
| def updateById(df: DataFrame, | ||
| targetDb: String, | ||
| tableName: String, | ||
| batchSize: Int = DEFAULT_BATCH_SIZE, | ||
| maxConnections: Int = DEFAULT_MAX_CONNECTIONS, | ||
| idColumn: String = DEFAULT_ID_COLUMN): Unit = { | ||
| import df.sparkSession.implicits._ | ||
| val schema = df.schema | ||
| val tableExists = checkTableExists(targetDb, tableName) | ||
|
|
||
| if (!tableExists) { | ||
| throw new NotImplementedError("Adding data to a non-existing table is not yet supported.") | ||
| } | ||
|
|
||
| val statements = df.map(genUpdateScript(_, schema, tableName, idColumn)).toDF() | ||
|
|
||
| executeStatements(targetDb, statements, batchSize, maxConnections) | ||
| } | ||
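A minimal usage sketch, assuming a DataFrame `df` whose schema matches an existing `people` table with an indexed `id` column; the connection string is a placeholder:

```scala
val url = "jdbc:postgresql://192.168.0.1:5432/MY_DB_NAME?user=MY_USER&password=PASSWORD"

// Updates every row of df in place, matching on the default "id" column.
JdbcUtils.updateById(df, url, "people", batchSize = 500, maxConnections = 20)
```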
|
|
||
| /** | ||
| * Insert a given DataFrame into a provided database via JDBC. | ||
| * | ||
| * @param df The dataframe to write to the database. | ||
| * @param targetDb The database to update provided as a jdbc URL, e.g. | ||
| * "jdbc:postgresql://192.168.0.1:5432/MY_DB_NAME?user=MY_USER&password=PASSWORD" | ||
| * @param tableName The name of the table to update | ||
| * @param batchSize The batch size to use, default 200 | ||
| * @param maxConnections Maximum number of simultaneous connections to open to the database | ||
| */ | ||
| @transient | ||
| def insert(df: DataFrame, | ||
| targetDb: String, | ||
| tableName: String, | ||
| batchSize: Int = DEFAULT_BATCH_SIZE, | ||
| maxConnections: Int = DEFAULT_MAX_CONNECTIONS): Unit = { | ||
| import df.sparkSession.implicits._ | ||
| val schema = df.schema | ||
| val tableExists = checkTableExists(targetDb, tableName) | ||
|
|
||
| if (!tableExists) { | ||
| throw new NotImplementedError("Adding data to a non-existing table is not yet supported.") | ||
| } | ||
|
|
||
| val statements = df.map(genInsertScript(_, schema, tableName)).toDF() | ||
|
|
||
| executeStatements(targetDb, statements, batchSize, maxConnections) | ||
| } | ||
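Usage mirrors `updateById`, with the same placeholder assumptions; every row of `df` is appended, and a missing table likewise raises `NotImplementedError`:

```scala
JdbcUtils.insert(df, url, "people")
```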
|
|
||
| /** | ||
| * Perform an upsert of a DataFrame to a given table. | ||
| * Reads the existing table to determine which records need to be updated, performs an update | ||
| * for those records, and performs an insert for the rest. | ||
| * | ||
| * Because an update is an expensive operation, the most efficient way of performing an update on | ||
| * a table is to first identify which records should be updated and which can be inserted. If | ||
| * possible, an index on a single column should be used to perform the update, for example, an | ||
| * "id" column. Performing an update on a multi-field index is even more computationally | ||
| * expensive. | ||
| * | ||
| * For improved performance, remove any indices from the table being updated (except for the | ||
| * index on the id column) and restore them after the update. | ||
| * | ||
| * @param sqlContext The active SQL Context | ||
| * @param df The dataset to write to the database. | ||
| * @param targetDb The database to update provided as a jdbc URL, e.g. | ||
| * "jdbc:postgresql://192.168.0.1:5432/MY_DB_NAME?user=MY_USER&password=PASSWORD" | ||
| * @param properties JDBC connection properties. | ||
| * @param tableName The name of the table to update | ||
| * @param primaryKeys A set representing the primary key for a database - the combination of | ||
| * column names that uniquely identifies a row in the dataframe. | ||
| * E.g. Set("first_name", "last_name", "address") | ||
| * @param batchSize The batch size to use, default 200 | ||
| * @param maxConnections Maximum number of simultaneous connections to open to the database | ||
| * @param idColumn The column to use as the key for resolving conflicts, default is "id" | ||
| */ | ||
| def upsert(sqlContext: SQLContext, | ||
| df: DataFrame, | ||
| targetDb: String, | ||
| properties: Properties, | ||
| tableName: String, | ||
| primaryKeys: Set[String], | ||
| batchSize: Int = DEFAULT_BATCH_SIZE, | ||
| maxConnections: Int = DEFAULT_MAX_CONNECTIONS, | ||
| idColumn: String = "id"): Unit = { | ||
| val storedDb = sqlContext.read.jdbc(targetDb, tableName, properties) | ||
|
|
||
| // Determine rows to upsert based on a key match in the database | ||
| val toUpsert = df.join(storedDb, primaryKeys.toSeq, "inner").select(df("*"), storedDb("id")) | ||
|
|
||
| // Insert those rows where there is no matching entry in the database already | ||
| // Do a select to ensure that columns are in the same order for except | ||
| // Note: we need to also get rid of timestamps for this comparison | ||
| val upsertKeys = toUpsert.select(primaryKeys.map(col).toSeq: _*) | ||
| val primaryKeysToInsert = df.select(primaryKeys.map(col).toSeq: _*).except(upsertKeys) | ||
| val toInsert = primaryKeysToInsert.join(df, primaryKeys.toSeq, "left_outer") | ||
|
|
||
| // Only perform an update if there are overlapping elements | ||
| if (toUpsert.count() > 0) { | ||
| updateById(toUpsert, targetDb, tableName, batchSize, maxConnections, idColumn) | ||
| } | ||
|
|
||
| insert(toInsert, targetDb, tableName, batchSize, maxConnections) | ||
| } | ||
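A usage sketch with the hypothetical composite key from the scaladoc, reusing the `spark`, `df`, and `url` placeholders above; rows whose key already exists in the table are routed to `updateById` via the table's `id` column, and the remainder are inserted:

```scala
import java.util.Properties

JdbcUtils.upsert(spark.sqlContext, df, url, new Properties(), "people",
  primaryKeys = Set("first_name", "last_name"))
```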
|
|
||
| /** | ||
| * Creates a table with a given schema. | ||
| */ | ||
| def createTable( | ||
| schema: StructType, | ||
| url: String, | ||
| table: String, | ||
| createTableOptions: String, | ||
| conn: Connection): Unit = { | ||
| def createTable(schema: StructType, | ||
| url: String, | ||
| table: String, | ||
| createTableOptions: String, | ||
| conn: Connection): Unit = { | ||
| val strSchema = schemaString(schema, url) | ||
| // Create the table if the table does not exist. | ||
| // To allow certain options to be appended when creating a new table, which can be | ||
|
|
||
To my knowledge, this is not the public API (see execution/package.scala#L21-L23).
Oh heck - this did seem like the appropriate place to put this though. Any thoughts on where it could live instead?
If this is worth adding (I am not supposed to decide this), I think we should make this work within current Spark APIs or new Spark APIs. cc @gatorsmile, who I guess knows this area better.
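For comparison, one way to get upsert semantics today without a new public API is a dialect-native statement driven from `foreachPartition`. Below is a sketch using PostgreSQL's INSERT ... ON CONFLICT; the table, columns, and URL are placeholders, not part of this PR:

```scala
import java.sql.DriverManager

import org.apache.spark.sql.DataFrame

// Sketch: per-partition upsert against a table people(id INT PRIMARY KEY, name TEXT).
def upsertViaOnConflict(df: DataFrame, url: String): Unit = {
  df.rdd.foreachPartition { rows =>
    val conn = DriverManager.getConnection(url)
    conn.setAutoCommit(false)
    val stmt = conn.prepareStatement(
      "INSERT INTO people (id, name) VALUES (?, ?) " +
        "ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name")
    try {
      rows.foreach { row =>
        stmt.setInt(1, row.getInt(0))
        stmt.setString(2, row.getString(1))
        stmt.addBatch()
      }
      stmt.executeBatch()  // parameters are bound, so no manual quote escaping is needed
      conn.commit()
    } finally {
      conn.close()         // also closes the prepared statement
    }
  }
}
```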