diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 0590aec77c3cd..f9342d7eb6582 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
+import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException, Statement}
+import java.util.Properties
 
 import scala.collection.JavaConverters._
 import scala.util.Try
@@ -26,11 +27,12 @@ import scala.util.control.NonFatal
 import org.apache.spark.TaskContext
 import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -721,15 +723,249 @@ object JdbcUtils extends Logging {
     )
   }
 
+  /**
+   * Check whether a table exists in a given database.
+   *
+   * @return True if the table exists.
+   */
+  def checkTableExists(targetDb: String, tableName: String): Boolean = {
+    val dbc: Connection = DriverManager.getConnection(targetDb)
+    try {
+      // Used to decide between an upsert and a plain DataFrame write to the database.
+      val dbm = dbc.getMetaData()
+      val tables = dbm.getTables(null, null, tableName, null)
+      tables.next() // next() returns false when the result set is empty
+    } finally {
+      dbc.close()
+    }
+  }
+
+  // Provide a reasonable starting batch size for database operations.
+  private val DEFAULT_BATCH_SIZE: Int = 200
+
+  // Limit the number of database connections; some databases suffer when many
+  // connections are open at once.
+  private val DEFAULT_MAX_CONNECTIONS: Int = 50
+
+  private val DEFAULT_ID_COLUMN: String = "id"
+
+  /**
+   * Given a DataFrame of SQL statements to execute, connect to the database and run
+   * them as a sequence of batched operations, opening at most "maxConnections"
+   * simultaneous connections.
+   *
+   * @param targetDb The database to connect to, provided as a JDBC URL, e.g.
+   *                 "jdbc:postgresql://192.168.0.1:5432/MY_DB_NAME?user=MY_USER&password=PASSWORD"
+   * @param statements A DataFrame of SQL statements (as strings) to execute
+   * @param batchSize The batch size to use, default 200
+   * @param maxConnections Maximum number of simultaneous connections to open to the database
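+   *
+   * For example (a sketch only; the URL and statement are illustrative):
+   * {{{
+   *   import spark.implicits._
+   *   val statements = Seq("UPDATE users SET name = 'Bo' WHERE id = 1;").toDF()
+   *   executeStatements("jdbc:postgresql://localhost:5432/mydb", statements)
+   * }}}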
+   */
+  private def executeStatements(targetDb: String,
+                                statements: DataFrame,
+                                batchSize: Int = DEFAULT_BATCH_SIZE,
+                                maxConnections: Int = DEFAULT_MAX_CONNECTIONS): Unit = {
+    // To avoid overloading the database, coalesce to at most maxConnections partitions.
+    val coalesced = if (statements.rdd.getNumPartitions > maxConnections) {
+      statements.coalesce(maxConnections)
+    } else {
+      statements
+    }
+
+    coalesced.foreachPartition { batch: Iterator[Row] =>
+      val dbc: Connection = DriverManager.getConnection(targetDb)
+      dbc.setAutoCommit(false)
+
+      val st: Statement = dbc.createStatement()
+
+      try {
+        // Execute and commit the statements in groups of batchSize.
+        batch.grouped(batchSize).foreach { rowBatch =>
+          rowBatch.foreach { statement =>
+            st.addBatch(statement.getString(0))
+          }
+          st.executeBatch()
+          dbc.commit()
+        }
+      } finally {
+        dbc.close()
+      }
+    }
+  }
+
+  /**
+   * Given a row of a dataframe, generate an insert statement for it.
+   *
+   * @param row A row of a dataframe
+   * @param schema The schema of the dataframe
+   * @param tableName The name of the table to insert into
+   * @return A SQL statement that inserts the referenced row into the table.
+   */
+  def genInsertScript(row: Row, schema: StructType, tableName: String): String = {
+    val schemaString = schema.map(_.name).mkString(", ")
+
+    // Values are bounded by single quotes, so escape any single quotes they contain.
+    val valString =
+      row.toSeq.map(v => "'" + v.toString.replaceAll("'", "''") + "'").mkString(",")
+
+    s"INSERT INTO $tableName ($schemaString) VALUES ($valString);"
+  }
+
+  /**
+   * Given a row of a dataframe and a key column, generate an update statement for it.
+   *
+   * @param row A row of a dataframe
+   * @param schema The schema of the dataframe
+   * @param tableName The name of the table to update
+   * @param keyField The name of the column that serves as the primary key
+   * @return A SQL statement that updates the referenced row in the table.
+   */
+  def genUpdateScript(row: Row, schema: StructType, tableName: String, keyField: String): String = {
+    val keyVal = row.get(schema.fieldIndex(keyField))
+    val zipped = row.toSeq.zip(schema.map(_.name))
+
+    val valString = zipped.flatMap { case (value, name) =>
+      // Values are bounded by single quotes, so escape any single quotes they contain
+      val noQuotes: String = value.toString.replaceAll("'", "''")
+      if (!name.equals(keyField)) Seq(s"$name='$noQuotes'") else Seq()
+    }.mkString(",\n")
+
+    s"UPDATE $tableName SET\n $valString\n WHERE $keyField = $keyVal;"
+  }
+
+  /**
+   * Write a given DataFrame to a database via JDBC by issuing update statements.
+   * If the table contains a row with a matching id, all other columns of that row are
+   * updated; rows without a matching id are left untouched (see upsert for
+   * insert-or-update semantics).
+   * Note: this expects the table to contain an indexed column to match rows on; the
+   * default is "id".
+   * TODO: Add support for arbitrary length keys
+   *
+   * @param df The dataframe to write to the database.
+   * @param targetDb The database to update, provided as a JDBC URL, e.g.
+   *                 "jdbc:postgresql://192.168.0.1:5432/MY_DB_NAME?user=MY_USER&password=PASSWORD"
+   * @param tableName The name of the table to update
+   * @param batchSize The batch size to use, default 200
+   * @param maxConnections Maximum number of simultaneous connections to open to the database
+   * @param idColumn The column to use as the primary key for matching rows, default "id".
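+   *
+   * For example (a sketch only; the URL and table name are illustrative):
+   * {{{
+   *   JdbcUtils.updateById(df, "jdbc:postgresql://localhost:5432/mydb", "users")
+   * }}}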
+   */
+  def updateById(df: DataFrame,
+                 targetDb: String,
+                 tableName: String,
+                 batchSize: Int = DEFAULT_BATCH_SIZE,
+                 maxConnections: Int = DEFAULT_MAX_CONNECTIONS,
+                 idColumn: String = DEFAULT_ID_COLUMN): Unit = {
+    import df.sparkSession.implicits._
+    val schema = df.schema
+    val tableExists = checkTableExists(targetDb, tableName)
+
+    if (!tableExists) {
+      throw new NotImplementedError("Adding data to a non-existing table is not yet supported.")
+    }
+
+    val statements = df.map(genUpdateScript(_, schema, tableName, idColumn)).toDF()
+
+    executeStatements(targetDb, statements, batchSize, maxConnections)
+  }
+
+  /**
+   * Insert a given DataFrame into a database table via JDBC.
+   *
+   * @param df The dataframe to write to the database.
+   * @param targetDb The database to update, provided as a JDBC URL, e.g.
+   *                 "jdbc:postgresql://192.168.0.1:5432/MY_DB_NAME?user=MY_USER&password=PASSWORD"
+   * @param tableName The name of the table to insert into
+   * @param batchSize The batch size to use, default 200
+   * @param maxConnections Maximum number of simultaneous connections to open to the database
+   */
+  def insert(df: DataFrame,
+             targetDb: String,
+             tableName: String,
+             batchSize: Int = DEFAULT_BATCH_SIZE,
+             maxConnections: Int = DEFAULT_MAX_CONNECTIONS): Unit = {
+    import df.sparkSession.implicits._
+    val schema = df.schema
+    val tableExists = checkTableExists(targetDb, tableName)
+
+    if (!tableExists) {
+      throw new NotImplementedError("Adding data to a non-existing table is not yet supported.")
+    }
+
+    val statements = df.map(genInsertScript(_, schema, tableName)).toDF()
+
+    executeStatements(targetDb, statements, batchSize, maxConnections)
+  }
+
+  /**
+   * Perform an upsert of a DataFrame into a given table: read the existing table to
+   * determine which records already exist, perform an update for those, and perform an
+   * insert for the rest.
+   *
+   * Because an update is an expensive operation, the most efficient approach is to
+   * first identify which records need to be updated and which can simply be inserted.
+   * If possible, an index on a single column (for example an "id" column) should be
+   * used for the update; updating against a multi-column index is even more
+   * computationally expensive.
+   *
+   * For improved performance, remove any indices from the table being updated (except
+   * for the index on the id column) and restore them after the update.
+   *
+   * @param sqlContext The active SQLContext
+   * @param df The dataframe to write to the database.
+   * @param targetDb The database to update, provided as a JDBC URL, e.g.
+   *                 "jdbc:postgresql://192.168.0.1:5432/MY_DB_NAME?user=MY_USER&password=PASSWORD"
+   * @param properties JDBC connection properties.
+   * @param tableName The name of the table to update
+   * @param primaryKeys The set of column names that together uniquely identify a row in
+   *                    the dataframe, e.g. Set("first_name", "last_name", "address")
+   * @param batchSize The batch size to use, default 200
+   * @param maxConnections Maximum number of simultaneous connections to open to the database
+   * @param idColumn The column to use as the key for resolving conflicts, default is "id"
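+   *
+   * For example (a sketch only; the URL, credentials, and table are illustrative):
+   * {{{
+   *   JdbcUtils.upsert(spark.sqlContext, df,
+   *     "jdbc:postgresql://localhost:5432/mydb?user=u&password=p",
+   *     new Properties(), "users", Set("first_name", "last_name"))
+   * }}}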
Set("first_name", "last_name", "address") + * @param batchSize The batch size to use, default 200 + * @param maxConnections Maximum number of simultaneous connections to open to the database + * @param idColumn The column to use as the key for resolving conflicts, default is "id" + */ + def upsert(sqlContext: SQLContext, + df: DataFrame, + targetDb: String, + properties: Properties, + tableName: String, + primaryKeys: Set[String], + batchSize: Int = DEFAULT_BATCH_SIZE, + maxConnections: Int = DEFAULT_MAX_CONNECTIONS, + idColumn: String = "id"): Unit = { + val storedDb = sqlContext.read.jdbc(targetDb, tableName, properties) + + // Determine rows to upsert based on a key match in the database + val toUpsert = df.join(storedDb, primaryKeys.toSeq, "inner").select(df("*"), storedDb("id")) + + // Insert those rows where there is no matching entry in the database already + // Do a select to ensure that columns are in the same order for except + // Note: we need to also get rid of timestamps for this comparison + val upsertKeys = toUpsert.select(primaryKeys.map(col).toSeq: _*) + val primaryKeysToInsert = df.select(primaryKeys.map(col).toSeq: _*).except(upsertKeys) + val toInsert = primaryKeysToInsert.join(df, primaryKeys.toSeq, "left_outer") + + // Only perform an update if there are overlapping elements + if (toUpsert.count() > 0) { + updateById(toUpsert, targetDb, tableName, batchSize, maxConnections, idColumn) + } + + insert(toInsert, targetDb, tableName, batchSize, maxConnections) + } + /** * Creates a table with a given schema. */ - def createTable( - schema: StructType, - url: String, - table: String, - createTableOptions: String, - conn: Connection): Unit = { + def createTable(schema: StructType, + url: String, + table: String, + createTableOptions: String, + conn: Connection): Unit = { val strSchema = schemaString(schema, url) // Create the table if the table does not exist. // To allow certain options to append when create a new table, which can be