diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 8962cc3821f36..dd7ee5dd68410 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -6369,6 +6369,18 @@
],
"sqlState" : "42K0E"
},
+ "UPSERT_KEY_COLUMNS_REQUIRED" : {
+ "message" : [
+ "Upsert requires key columns."
+ ],
+ "sqlState" : "42703"
+ },
+ "UPSERT_NOT_ALLOWED" : {
+ "message" : [
+ "Upsert is not allowed for table
: ."
+ ],
+ "sqlState" : "0A000"
+ },
"USER_DEFINED_FUNCTIONS" : {
"message" : [
"User defined function is invalid:"
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index 62f088ebc2b6d..4a2ce7aa759ba 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -40,7 +40,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
-class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
+class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new MsSQLServerDatabaseOnDocker
override def dataPreparation(conn: Connection): Unit = {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
index cf547b93aa0ba..e123267d9716c 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
@@ -61,7 +61,8 @@ import org.apache.spark.tags.DockerTest
* and with Oracle Express Edition versions 18.4.0 and 21.4.0
*/
@DockerTest
-class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSparkSession {
+class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSparkSession
+ with UpsertTests {
import testImplicits._
override val db = new OracleDatabaseOnDocker
@@ -182,6 +183,13 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSpark
conn.commit()
}
+ // Oracle requires the TIMESTAMP keyword before timestamp literals, so patch the UpsertTests test data
+ override def getUpsertTestTableInserts(tableName: String): Seq[String] = {
+ super.getUpsertTestTableInserts(tableName).map(sql =>
+ sql.replace(", '1996-", ", TIMESTAMP '1996-")
+ )
+ }
+
test("SPARK-16625: Importing Oracle numeric types") {
Seq("true", "false").foreach { flag =>
withSQLConf((SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key, flag)) {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 5c985da226b06..d28b9a9827278 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -40,7 +40,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
-class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
+class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new PostgresDatabaseOnDocker
override def dataPreparation(conn: Connection): Unit = {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala
new file mode 100644
index 0000000000000..3c666809a4fd9
--- /dev/null
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Timestamp
+import java.util.{Properties, UUID}
+
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.functions.{col, lit, rand, when}
+import org.apache.spark.sql.types.{DoubleType, LongType}
+
+// schema of upsert test table
+case class Upsert(id: Long, ts: Timestamp, v1: Double, v2: Option[Double])
+
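+// Shared upsert tests, mixed into the MsSqlServer, Oracle and Postgres docker suites.
+// They exercise the "upsert" and "upsertKeyColumns" JDBC write options end-to-end against
+// a real database; suites can override getUpsertTestTableInserts to adapt the test data
+// to their SQL dialect.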
+trait UpsertTests {
+ self: DockerJDBCIntegrationSuite =>
+
+ import testImplicits._
+
+ def newTestTableName(): String = "upsert" + UUID.randomUUID().toString.replaceAll("-", "")
+
+ def getUpsertTestTableInserts(tableName: String): Seq[String] =
+ Seq(
+ s"INSERT INTO $tableName VALUES (1, '1996-01-01 01:23:45', 1.234, 123.456)",
+ s"INSERT INTO $tableName VALUES (1, '1996-01-01 01:23:46', 1.235, 123.457)",
+ s"INSERT INTO $tableName VALUES (2, '1996-01-01 01:23:45', 2.345, 234.567)",
+ s"INSERT INTO $tableName VALUES (2, '1996-01-01 01:23:46', 2.346, 234.568)"
+ )
+
+ def createUpsertTestTable(tableName: String): Unit = {
+ // get jdbc connection
+ val conn = eventually(connectionTimeout, interval(1.second)) {
+ getConnection()
+ }
+
+ // create test table by writing empty dataset to it
+ spark.emptyDataset[Upsert].write.options(Map(
+ "upsert" -> "true",
+ "upsertKeyColumns" -> "id, ts"
+ )).jdbc(jdbcUrl, tableName, new Properties)
+
+ // insert test data
+ try {
+ getUpsertTestTableInserts(tableName).foreach(sql =>
+ conn.prepareStatement(sql).executeUpdate())
+ } finally {
+ conn.close()
+ }
+ }
+
+ test(s"Upsert existing table") { doTestUpsert(tableExists = true) }
+ test(s"Upsert non-existing table") { doTestUpsert(tableExists = false) }
+
+ Seq(
+ Seq("ts", "id", "v1", "v2"),
+ Seq("ts", "v1", "id", "v2"),
+ Seq("ts", "v1", "v2", "id"),
+ Seq("v2", "v1", "ts", "id")
+ ).foreach { columns =>
+ test(s"Upsert with varying column order - ${columns.mkString(",")}") {
+ doTestUpsert(tableExists = true, Some(columns))
+ }
+ }
+
+ test(s"Upsert with column subset") {
+ doTestUpsert(tableExists = true, Some(Seq("id", "ts", "v1")))
+ }
+
+ def doTestUpsert(tableExists: Boolean, project: Option[Seq[String]] = None): Unit = {
+ // either project is None, or it contains all of Seq("id", "ts")
+ assert(project.forall(p => Seq("id", "ts").forall(p.contains)))
+ val df = Seq(
+ (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 123.457), // row unchanged
+ (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 234.567), // updates v1
+ (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 234.569), // updates v1 and v2
+ (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 345.678) // inserts new row
+ ).toDF("id", "ts", "v1", "v2").repartition(10)
+
+ val table = newTestTableName()
+ if (tableExists) { createUpsertTestTable(table) }
+ val options = Map(
+ "numPartitions" -> "10", "upsert" -> "true", "upsertKeyColumns" -> "id, ts"
+ )
+ project.map(columns => df.select(columns.map(col): _*)).getOrElse(df)
+ .write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties)
+
+ val actual = spark.read.jdbc(jdbcUrl, table, new Properties).sort("id", "ts")
+ // cast to the expected types; required by OracleIntegrationSuite, which reads numeric columns back as decimals
+ .select($"id".cast(LongType), $"ts", $"v1".cast(DoubleType), $"v2".cast(DoubleType))
+ .collect()
+ val existing = if (tableExists) {
+ Seq((1, Timestamp.valueOf("1996-01-01 01:23:45"), 1.234, Some(123.456)))
+ } else {
+ Seq.empty
+ }
+ // either project is None, or it contains all of Seq("v1", "v2")
+ val upsertedRows = if (project.forall(p => Seq("v1", "v2").forall(p.contains))) {
+ Seq(
+ (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, Some(123.457)),
+ (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, Some(234.567)),
+ (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, Some(234.569)),
+ (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, Some(345.678))
+ )
+ } else if (project.exists(!_.contains("v2"))) {
+ // column v2 not updated
+ Seq(
+ (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, Some(123.457)),
+ (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, Some(234.567)),
+ (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, Some(234.568)),
+ (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, None)
+ )
+ } else {
+ throw new RuntimeException("Unsupported test case")
+ }
+ val expected = (existing ++ upsertedRows).map { case (id, ts, v1, v2) =>
+ Row(Integer.valueOf(id), ts, v1.doubleValue(), v2.map(_.doubleValue()).orNull)
+ }
+ assert(actual === expected)
+ }
+
+ test(s"Upsert concurrency") {
+ val table = newTestTableName()
+
+ // create a table with 100k rows
+ val init =
+ spark.range(100000)
+ .withColumn("ts", lit(Timestamp.valueOf("2023-06-07 12:34:56")))
+ .withColumn("v", rand())
+
+ // upsert 10,000 rows: 100 ids in each of 100 offset blocks,
+ // written from 32 concurrent tasks
+ val patch =
+ spark.range(100)
+ .join(spark.range(100).select(($"id" * 1000).as("offset")))
+ .repartition(32)
+ .select(
+ ($"id" + $"offset").as("id"),
+ lit(Timestamp.valueOf("2023-06-07 12:34:56")).as("ts"),
+ lit(-1.0).as("v")
+ )
+
+ init
+ .write
+ .mode(SaveMode.Overwrite)
+ .options(Map("upsert" -> "true", "upsertKeyColumns" -> "id, ts"))
+ .jdbc(jdbcUrl, table, new Properties)
+
+ patch
+ .write
+ .mode(SaveMode.Append)
+ .options(Map("upsert" -> "true", "upsertKeyColumns" -> "id, ts"))
+ .jdbc(jdbcUrl, table, new Properties)
+
+ // verify that 100 * 100 = 10,000 rows were patched and the remaining 90,000 rows are untouched
+ val result = spark.read.jdbc(jdbcUrl, table, new Properties)
+ .select($"id", when($"v" === -1.0, true).otherwise(false).as("patched"))
+ .groupBy($"patched")
+ .count()
+ .sort($"patched")
+ .as[(Boolean, Long)]
+ .collect()
+ assert(result === Seq((false, 90000), (true, 10000)))
+ }
+
+ test("Upsert with columns that require quotes") {}
+ test("Upsert with table name that requires quotes") {}
+ test("Upsert null values") {}
+
+ test("Write with unspecified mode with upsert") {}
+ test("Write with overwrite mode with upsert") {}
+ test("Write with error-if-exists mode with upsert") {}
+ test("Write with ignore mode with upsert") {}
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index b58605ae95420..ae945b7b8be84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -1576,12 +1576,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map("changes" -> changes.toString()))
}
- private def tableDoesNotSupportError(cmd: String, table: Table): Throwable = {
+ private def tableDoesNotSupportError(cmd: String, table: Table): Throwable =
+ tableNameDoesNotSupportError(cmd, table.name)
+
+ private def tableNameDoesNotSupportError(cmd: String, table: String): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1121",
messageParameters = Map(
"cmd" -> cmd,
- "table" -> table.name))
+ "table" -> table))
}
def tableDoesNotSupportReadsError(table: Table): Throwable = {
@@ -1608,6 +1611,22 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
tableDoesNotSupportError("atomic partition management", table)
}
+ def tableDoesNotSupportUpsertError(table: String, dialect: String): Throwable = {
+ tableNameDoesNotSupportError(s"upsert in dialect $dialect", table)
+ }
+
+ def upsertKeyColumnsRequiredError(): Throwable = {
+ new AnalysisException(
+ errorClass = "UPSERT_KEY_COLUMNS_REQUIRED",
+ messageParameters = Map.empty)
+ }
+
+ def upsertNotAllowedError(table: String, reason: String): Throwable = {
+ new AnalysisException(
+ errorClass = "UPSERT_NOT_ALLOWED",
+ messageParameters = Map("table" -> table, "reason" -> reason))
+ }
+
def tableIsNotRowLevelOperationTableError(table: Table): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1122",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index f0c638b7d07c8..4b83e9bd7e9e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -169,6 +169,13 @@ class JDBCOptions(
val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean
val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean)
+
+ // whether to upsert into the table in the JDBC database
+ val isUpsert = parameters.getOrElse(JDBC_UPSERT, "false").toBoolean
+ // the key columns used to decide whether a row is updated or inserted in upsert mode
+ val upsertKeyColumns = parameters.get(JDBC_UPSERT_KEY_COLUMNS).map(_.split(",").map(_.trim))
+ .getOrElse(Array.empty)
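+ // Minimal usage sketch of these options; the table name "people" and jdbcUrl are illustrative only:
+ //   df.write
+ //     .mode(SaveMode.Append)
+ //     .option("upsert", "true")
+ //     .option("upsertKeyColumns", "id, ts")
+ //     .jdbc(jdbcUrl, "people", new Properties())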
+
// the create table option , which can be table_options or partition_options.
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
// TODO: to reuse the existing partition parameters for those partition specific options
@@ -310,6 +317,8 @@ object JDBCOptions {
val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate")
+ val JDBC_UPSERT = newOption("upsert")
+ val JDBC_UPSERT_KEY_COLUMNS = newOption("upsertKeyColumns")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index d9be1a1e3f674..7f431da7f4e09 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -68,7 +68,11 @@ class JdbcRelationProvider extends CreatableRelationProvider
case SaveMode.Append =>
val tableSchema = JdbcUtils.getSchemaOption(conn, options)
- saveTable(df, tableSchema, isCaseSensitive, options)
+ if (options.isUpsert) {
+ upsertTable(df, tableSchema, isCaseSensitive, options)
+ } else {
+ saveTable(df, tableSchema, isCaseSensitive, options)
+ }
case SaveMode.ErrorIfExists =>
throw QueryCompilationErrors.tableOrViewAlreadyExistsError(options.table)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 8112cf1c80ef9..531ff62d09b73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.jdbc
import java.math.{BigDecimal => JBigDecimal}
import java.nio.charset.StandardCharsets
-import java.sql.{Connection, Date, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException, Time, Timestamp}
+import java.sql.{Connection, Date, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException, Statement, Time, Timestamp}
import java.time.{Instant, LocalDate}
import java.util
@@ -46,7 +46,8 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.JDBC_TABLE_NAME
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, MergeByTempTable, NoopDialect}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.unsafe.types.UTF8String
@@ -109,19 +110,16 @@ object JdbcUtils extends Logging with SQLConfHelper {
JdbcDialects.get(url).isCascadingTruncateTable()
}
- /**
- * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
- */
- def getInsertStatement(
+ def getColumns(
table: String,
rddSchema: StructType,
tableSchema: Option[StructType],
isCaseSensitive: Boolean,
- dialect: JdbcDialect): String = {
- val columns = if (tableSchema.isEmpty) {
+ dialect: JdbcDialect): Array[StructField] = {
+ if (tableSchema.isEmpty) {
rddSchema.fields
} else {
- // The generated insert statement needs to follow rddSchema's column sequence and
+ // The column sequence needs to follow rddSchema's column sequence and
// tableSchema's column names. When appending data into some case-sensitive DBMSs like
// PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
// RDD column names for user convenience.
@@ -132,6 +130,18 @@ object JdbcUtils extends Logging with SQLConfHelper {
}
}
}
+ }
+
+ /**
+ * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
+ */
+ def getInsertStatement(
+ table: String,
+ rddSchema: StructType,
+ tableSchema: Option[StructType],
+ isCaseSensitive: Boolean,
+ dialect: JdbcDialect): String = {
+ val columns = getColumns(table, rddSchema, tableSchema, isCaseSensitive, dialect)
dialect.insertIntoTable(table, columns)
}
@@ -760,14 +770,16 @@ object JdbcUtils extends Logging with SQLConfHelper {
* updated even with error if it doesn't support transaction, as there're dirty outputs.
*/
def savePartition(
- table: String,
+ connection: Option[Connection],
iterator: Iterator[Row],
rddSchema: StructType,
insertStmt: String,
batchSize: Int,
dialect: JdbcDialect,
isolationLevel: Int,
- options: JDBCOptions): Unit = {
+ options: JDBCOptions,
+ batchExecuted: () => Unit = () => (),
+ finallyCallback: () => Unit = () => ()): Unit = {
if (iterator.isEmpty) {
return
@@ -775,7 +787,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
val outMetrics = TaskContext.get().taskMetrics().outputMetrics
- val conn = dialect.createConnectionFactory(options)(-1)
+ val closeConn = connection.isEmpty
+ val conn = connection.getOrElse(dialect.createConnectionFactory(options)(-1))
var committed = false
var finalIsolationLevel = Connection.TRANSACTION_NONE
@@ -836,11 +849,13 @@ object JdbcUtils extends Logging with SQLConfHelper {
totalRowCount += 1
if (rowCount % batchSize == 0) {
stmt.executeBatch()
+ batchExecuted()
rowCount = 0
}
}
if (rowCount > 0) {
stmt.executeBatch()
+ batchExecuted()
}
} finally {
stmt.close()
@@ -870,6 +885,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
}
throw e
} finally {
+ finallyCallback()
if (!committed) {
// The stage must fail. We got here through an exception path, so
// let the exception through unless rollback() or close() want to
@@ -879,13 +895,17 @@ object JdbcUtils extends Logging with SQLConfHelper {
} else {
outMetrics.setRecordsWritten(totalRowCount)
}
- conn.close()
+ if (closeConn) {
+ conn.close()
+ }
} else {
outMetrics.setRecordsWritten(totalRowCount)
// The stage must succeed. We cannot propagate any exception close() might throw.
try {
- conn.close()
+ if (closeConn) {
+ conn.close()
+ }
} catch {
case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
}
@@ -893,6 +913,59 @@ object JdbcUtils extends Logging with SQLConfHelper {
}
}
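+ /**
+ * Upserts a partition of rows into `table`: the rows are staged in a dedicated temporary
+ * table via `savePartition`, and after each executed batch the temporary table is merged
+ * into `table` and truncated again. The temporary table is dropped once the partition has
+ * been written.
+ */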
+ def upsertPartition(
+ table: String,
+ iterator: Iterator[Row],
+ rddSchema: StructType,
+ tblSchema: Option[StructType],
+ isCaseSensitive: Boolean,
+ batchSize: Int,
+ dialect: JdbcDialect with MergeByTempTable,
+ isolationLevel: Int,
+ options: JdbcOptionsInWrite): Unit = {
+ val conn = dialect.createConnectionFactory(options)(-1)
+ val stmt = conn.createStatement()
+ stmt.setQueryTimeout(options.queryTimeout)
+
+ val tempTable = dialect.createTempTableName()
+ createTable(conn, tempTable, rddSchema, isCaseSensitive, options, temporary = true)
+ val insertStmt = getInsertStatement(tempTable, rddSchema, tblSchema, isCaseSensitive, dialect)
+ val tempParams = options.parameters.updated(JDBC_TABLE_NAME, tempTable)
+ val tempOptions = new JdbcOptionsInWrite(tempParams)
+ val columns = getColumns(table, rddSchema, tblSchema, isCaseSensitive, dialect)
+ val upsert = () => upsertBatch(stmt, table, tempTable, columns,
+ options.upsertKeyColumns, dialect)()
+ val finallyCallback = () => try {
+ dropTable(conn, tempTable, options)
+ } catch {
+ case e: Exception => logWarning(s"Exception dropping temp table $tempTable", e)
+ }
+
+ savePartition(Some(conn), iterator, rddSchema, insertStmt, batchSize, dialect,
+ isolationLevel, tempOptions, upsert, finallyCallback)
+
+ // The stage must succeed. We cannot propagate any exception close() might throw.
+ try {
+ conn.close()
+ } catch {
+ case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
+ }
+ }
+
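+ /**
+ * Merges the rows currently staged in `tempTable` into `table` and truncates `tempTable`
+ * so that it is empty for the next batch.
+ */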
+ def upsertBatch(
+ stmt: Statement,
+ table: String,
+ tempTable: String,
+ columns: Array[StructField],
+ upsertKeyColumns: Array[String],
+ dialect: JdbcDialect with MergeByTempTable)(): Unit = {
+ // merge the batch staged in tempTable into table
+ dialect.merge(stmt, tempTable, table, columns, upsertKeyColumns)
+
+ // truncate tempTable so it is empty for the next batch
+ stmt.executeUpdate(dialect.getTruncateQuery(tempTable))
+ }
+
/**
* Compute the schema string for this RDD.
*/
@@ -992,8 +1065,72 @@ object JdbcUtils extends Logging with SQLConfHelper {
case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
case _ => df
}
- repartitionedDF.foreachPartition { iterator => savePartition(
- table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, options)
+ repartitionedDF.foreachPartition { iterator: Iterator[Row] => savePartition(
+ None, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, options)
+ }
+ }
+
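+ /**
+ * Upserts the content of the DataFrame into the JDBC table. Requires a dialect with
+ * `MergeByTempTable` support, non-empty `upsertKeyColumns`, and at least one non-key
+ * column to update.
+ */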
+ def upsertTable(
+ df: DataFrame,
+ tableSchema: Option[StructType],
+ isCaseSensitive: Boolean,
+ options: JdbcOptionsInWrite): Unit = {
+ val url = options.url
+ val table = options.table
+ val dialect = JdbcDialects.get(url)
+ val rddSchema = df.schema
+ val batchSize = options.batchSize
+ val isolationLevel = options.isolationLevel
+
+ if (!dialect.isInstanceOf[MergeByTempTable]) {
+ throw QueryCompilationErrors.tableDoesNotSupportUpsertError(
+ options.table, dialect.getClass.getSimpleName)
+ }
+ val dialectWithMerge = dialect.asInstanceOf[JdbcDialect with MergeByTempTable]
+
+ if (options.upsertKeyColumns.isEmpty) {
+ throw QueryCompilationErrors.upsertKeyColumnsRequiredError()
+ }
+
+ val columns = getColumns(table, rddSchema, tableSchema, isCaseSensitive, dialect)
+ if (columns.forall(col => options.upsertKeyColumns.contains(col.name))) {
+ throw QueryCompilationErrors.upsertNotAllowedError(options.table,
+ "table has only key columns")
+ }
+
+ val repartitionedDF = options.numPartitions match {
+ case Some(n) if n <= 0 => throw QueryExecutionErrors.invalidJdbcNumPartitionsError(
+ n, JDBCOptions.JDBC_NUM_PARTITIONS)
+ case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
+ case _ => df
+ }
+ repartitionedDF.foreachPartition { iterator: Iterator[Row] => upsertPartition(
+ table, iterator, rddSchema, tableSchema, isCaseSensitive, batchSize, dialectWithMerge,
+ isolationLevel, options)
+ }
+ }
+
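+ /**
+ * In upsert mode, returns the schema with the upsert key columns marked non-nullable
+ * (a primary key is created on them) together with the `MergeByTempTable` dialect;
+ * otherwise returns the schema unchanged and no dialect.
+ */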
+ private def getMergeSchemaAndDialect(
+ schema: StructType,
+ dialect: JdbcDialect,
+ options: JdbcOptionsInWrite): (StructType, Option[JdbcDialect with MergeByTempTable]) = {
+ if (options.isUpsert) {
+ if (!dialect.isInstanceOf[MergeByTempTable]) {
+ throw QueryCompilationErrors.tableDoesNotSupportUpsertError(
+ options.table, dialect.getClass.getSimpleName)
+ }
+ val mergeDialect = dialect.asInstanceOf[JdbcDialect with MergeByTempTable]
+
+ // upsert requires a primary index on the upsert key columns;
+ // the key columns must be non-nullable to support such an index
+ val mergeSchema = StructType(schema.fields.map {
+ case field if options.upsertKeyColumns.contains(field.name) => field.copy(nullable = false)
+ case field => field
+ }.toSeq)
+
+ (mergeSchema, Some(mergeDialect))
+ } else {
+ (schema, None)
}
}
@@ -1003,16 +1140,30 @@ object JdbcUtils extends Logging with SQLConfHelper {
def createTable(
conn: Connection,
tableName: String,
- schema: StructType,
+ tableSchema: StructType,
caseSensitive: Boolean,
- options: JdbcOptionsInWrite): Unit = {
+ options: JdbcOptionsInWrite,
+ temporary: Boolean = false): Unit = {
val statement = conn.createStatement
val dialect = JdbcDialects.get(options.url)
+
+ // in upsert mode the schema has to be adjusted and the MergeByTempTable dialect is needed
+ val (schema, mergeDialect) = getMergeSchemaAndDialect(tableSchema, dialect, options)
+
val strSchema = schemaString(
dialect, schema, caseSensitive, options.createTableColumnTypes)
try {
statement.setQueryTimeout(options.queryTimeout)
- dialect.createTable(statement, tableName, strSchema, options)
+ if (temporary) {
+ dialect.createTempTable(statement, tableName, strSchema, options)
+ } else {
+ dialect.createTable(statement, tableName, strSchema, options)
+ }
+ if (options.isUpsert) {
+ // creating a table that is going to be upserted into requires a primary index
+ assert(mergeDialect.isDefined)
+ mergeDialect.foreach(_.createPrimaryIndex(statement, tableName, options.upsertKeyColumns))
+ }
if (options.tableComment.nonEmpty) {
try {
val tableCommentQuery = dialect.getTableCommentQuery(tableName, options.tableComment)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 82f6f5c6264c4..91c43a4b049b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.jdbc
-import java.sql.{Connection, SQLException, Types}
+import java.sql.{Connection, SQLException, Statement, Types}
import java.util
import java.util.Locale
import java.util.concurrent.ConcurrentHashMap
@@ -33,10 +33,11 @@ import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference}
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite, JdbcUtils}
import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, MetadataBuilder, ShortType, StringType, TimestampType}
-private[sql] case class H2Dialect() extends JdbcDialect with NoLegacyJDBCError {
+private[sql] case class H2Dialect() extends JdbcDialect with MergeByTempTable
+ with NoLegacyJDBCError {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
@@ -96,6 +97,14 @@ private[sql] case class H2Dialect() extends JdbcDialect with NoLegacyJDBCError {
functionMap.clear()
}
+ override def createTempTable(
+ statement: Statement,
+ tableName: String,
+ strSchema: String,
+ options: JdbcOptionsInWrite): Unit = {
+ statement.executeUpdate(s"CREATE LOCAL TEMPORARY TABLE $tableName ($strSchema)")
+ }
+
// CREATE INDEX syntax
// https://www.h2database.com/html/commands.html#create_index
override def createIndex(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 0d10b8e04484e..503c4cfbff221 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -286,6 +286,21 @@ abstract class JdbcDialect extends Serializable with Logging {
s"INSERT INTO $table ($columns) VALUES ($placeholders)"
}
+ /**
+ * Creates a temporary table that is used as a staging area for upserts.
+ * Dialects can override this to use their database-specific temporary-table DDL.
+ * @param statement statement used to execute the DDL
+ * @param tableName name of the temporary table
+ * @param strSchema schema of the table as SQL column definitions
+ * @param options JDBC write options
+ */
+ def createTempTable(
+ statement: Statement,
+ tableName: String,
+ strSchema: String,
+ options: JdbcOptionsInWrite): Unit = {
+ statement.executeUpdate(s"CREATE TEMPORARY TABLE $tableName ($strSchema)")
+ }
+
/**
* Get the SQL query that should be used to find if the given table exists. Dialects can
* override this method to return a query that works best in a particular database.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MergeByTempTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MergeByTempTable.scala
new file mode 100644
index 0000000000000..42f40972e0eff
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MergeByTempTable.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Statement
+import java.util.UUID
+
+import org.apache.spark.sql.types.StructField
+
+
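+/**
+ * Mixin for `JdbcDialect`s that can implement upserts by staging rows in a temporary table
+ * and merging it into the destination table with a SQL MERGE statement.
+ */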
+trait MergeByTempTable {
+ self: JdbcDialect =>
+
+ def createTempTableName(): String =
+ "temp" + UUID.randomUUID().toString.replaceAll("-", "")
+
+ def getCreatePrimaryIndex(tableName: String, columns: Array[String]): String = {
+ val indexColumns = columns.map(quoteIdentifier).mkString(", ")
+ s"ALTER TABLE $tableName ADD PRIMARY KEY ($indexColumns)"
+ }
+
+ def createPrimaryIndex(
+ stmt: Statement,
+ tableName: String,
+ indexColumns: Array[String]): Unit = {
+ val sql = getCreatePrimaryIndex(tableName, indexColumns)
+ stmt.executeUpdate(sql)
+ }
+
+ /**
+ * Returns a SQL query that merges `sourceTableName` into `destinationTableName`
+ * w.r.t. the `keyColumns`.
+ *
+ * Table names `destinationTableName` and `sourceTableName`, as well as columns in `columns`
+ * are expected to be quoted by `JdbcDialect.quoteIdentifier`.
+ *
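+ * For example, for source table `source`, destination table `destination`, columns
+ * `id`, `ts`, `v1`, `v2` and key columns `id`, `ts`, the default implementation
+ * produces (modulo whitespace):
+ * {{{
+ *   MERGE INTO destination AS dst
+ *     USING source AS src
+ *     ON (dst."id" = src."id" AND dst."ts" = src."ts")
+ *     WHEN MATCHED THEN
+ *       UPDATE SET "v1" = src."v1", "v2" = src."v2"
+ *     WHEN NOT MATCHED THEN
+ *       INSERT ("id", "ts", "v1", "v2") VALUES (src."id", src."ts", src."v1", src."v2");
+ * }}}
+ *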
+ * @param sourceTableName name of the (temporary) table holding the rows to merge
+ * @param destinationTableName name of the table to merge into
+ * @param columns columns of the rows to merge
+ * @param keyColumns key columns identifying matching rows
+ * @return the MERGE SQL statement
+ */
+ def getMergeQuery(
+ sourceTableName: String,
+ destinationTableName: String,
+ columns: Array[StructField],
+ keyColumns: Array[String]): String = {
+ val indexColumns = keyColumns.map(quoteIdentifier)
+ val mergeCondition = indexColumns.map(k => s"dst.$k = src.$k").mkString(" AND ")
+ val updateClause = columns.filterNot(col => keyColumns.contains(col.name))
+ .map(col => quoteIdentifier(col.name))
+ .map(col => s"$col = src.$col")
+ .mkString(", ")
+ val quotedColumns = columns.map(col => quoteIdentifier(col.name))
+ val insertColumns = quotedColumns.mkString(", ")
+ val insertValues = quotedColumns.map(k => s"src.$k").mkString(", ")
+
+ s"""
+ |MERGE INTO $destinationTableName AS dst
+ | USING $sourceTableName AS src
+ | ON ($mergeCondition)
+ | WHEN MATCHED THEN
+ | UPDATE SET $updateClause
+ | WHEN NOT MATCHED THEN
+ | INSERT ($insertColumns) VALUES ($insertValues);
+ |""".stripMargin
+ }
+
+ /**
+ * Merges table `sourceTableName` into `destinationTableName` w.r.t. the `keyColumns`.
+ *
+ * Table names `destinationTableName` and `sourceTableName`, as well as columns in `columns`
+ * are expected to be quoted by `JdbcDialect.quoteIdentifier`.
+ *
+ * @param stmt statement used to execute the merge
+ * @param sourceTableName name of the (temporary) table holding the rows to merge
+ * @param destinationTableName name of the table to merge into
+ * @param columns columns of the rows to merge
+ * @param keyColumns key columns identifying matching rows
+ * @return number of rows affected by the merge
+ */
+ def merge(
+ stmt: Statement,
+ sourceTableName: String,
+ destinationTableName: String,
+ columns: Array[StructField],
+ keyColumns: Array[String]): Int = {
+ val sql = getMergeQuery(sourceTableName, destinationTableName, columns, keyColumns)
+ stmt.executeUpdate(sql)
+ }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 531e5d4f0f3a9..76e12d79a0637 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.jdbc
-import java.sql.SQLException
+import java.sql.{SQLException, Statement}
import java.util.Locale
import scala.util.control.NonFatal
@@ -28,13 +28,14 @@ import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, SortDirection}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.MsSqlServerDialect.{GEOGRAPHY, GEOMETRY}
import org.apache.spark.sql.types._
-private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCError {
+private case class MsSqlServerDialect() extends JdbcDialect with MergeByTempTable
+ with NoLegacyJDBCError {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver")
@@ -270,6 +271,23 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
new MsSqlServerSQLQueryBuilder(this, options)
override def supportsLimit: Boolean = true
+
+ override def createTempTableName(): String = "##" + super.createTempTableName()
+
+ override def createTempTable(
+ statement: Statement,
+ tableName: String,
+ strSchema: String,
+ options: JdbcOptionsInWrite): Unit = {
+ // SQL Server needs no dedicated CREATE syntax; the ## prefix added by createTempTableName already makes this a global temporary table
+ super.createTable(statement, tableName, strSchema, options)
+ }
+
+ override def getCreatePrimaryIndex(tableName: String, columns: Array[String]): String = {
+ val indexColumns = columns.map(quoteIdentifier).mkString(", ")
+ s"ALTER TABLE $tableName ADD PRIMARY KEY CLUSTERED ($indexColumns)"
+ }
+
}
private object MsSqlServerDialect {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 851b0e04d5e5c..0f6e834001dc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.jdbc
-import java.sql.{Date, SQLException, Timestamp, Types}
+import java.sql.{Date, SQLException, Statement, Timestamp, Types}
import java.util.Locale
import scala.util.control.NonFatal
@@ -26,12 +26,13 @@ import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException}
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.connector.expressions.{Expression, Literal}
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite}
import org.apache.spark.sql.jdbc.OracleDialect._
import org.apache.spark.sql.types._
-private case class OracleDialect() extends JdbcDialect with SQLConfHelper with NoLegacyJDBCError {
+private case class OracleDialect()
+ extends JdbcDialect with MergeByTempTable with SQLConfHelper with NoLegacyJDBCError {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle")
@@ -276,6 +277,27 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
case _ => super.classifyException(e, condition, messageParameters, description, isRuntime)
}
}
+
+ override def createTempTable(
+ statement: Statement,
+ tableName: String,
+ strSchema: String,
+ options: JdbcOptionsInWrite): Unit = {
+ statement.executeUpdate(s"CREATE GLOBAL TEMPORARY TABLE $tableName ($strSchema) " +
+ s"ON COMMIT DELETE ROWS")
+ }
+
+ override def getMergeQuery(
+ sourceTableName: String,
+ destinationTableName: String,
+ columns: Array[StructField],
+ keyColumns: Array[String]): String = {
+ // Oracle rejects the AS keyword in table aliases and the trailing semicolon of the standard MERGE statement
+ super.getMergeQuery(sourceTableName, destinationTableName, columns, keyColumns)
+ .replace(" AS dst\n", " dst\n")
+ .replace(" AS src\n", " src\n")
+ .replace(");\n", ")\n")
+ }
}
private[jdbc] object OracleDialect {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index b4cd5f578ccd1..1cf4cceb14b33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.types._
private case class PostgresDialect()
- extends JdbcDialect with SQLConfHelper with NoLegacyJDBCError {
+ extends JdbcDialect with MergeByTempTable with SQLConfHelper with NoLegacyJDBCError {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 6896f6993fb33..007ce3c3bc295 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -78,14 +78,11 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
}
}
- object JdbcClientTypes {
- val INTEGER = "INTEGER"
- val STRING = "CHARACTER VARYING"
+ val testMergeByTempTableDialect = new JdbcDialect with MergeByTempTable {
+ override def canHandle(url: String): Boolean = url.startsWith("jdbc:merge-by-temp")
}
- def defaultMetadata(
- dataType: DataType,
- jdbcClientType: String): Metadata = new MetadataBuilder()
+ def defaultMetadata(dataType: DataType): Metadata = new MetadataBuilder()
.putLong("scale", 0)
.putBoolean("isTimestampNTZ", false)
.putBoolean("isSigned", dataType.isInstanceOf[NumericType])
@@ -1340,6 +1337,59 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
}
}
+ test("MergeByTempTable: Create temp table name") {
+ val tmp1 = testMergeByTempTableDialect.createTempTableName()
+ assert(tmp1.nonEmpty)
+ val tmp2 = testMergeByTempTableDialect.createTempTableName()
+ assert(tmp2.nonEmpty)
+ assert(tmp2 !== tmp1)
+ }
+
+ test("MergeByTempTable: Create temp table name - MsSqlServer") {
+ val msSqlServerDialect = JdbcDialects.get("jdbc:sqlserver")
+ assert(msSqlServerDialect.isInstanceOf[MergeByTempTable])
+ val upsert = msSqlServerDialect.asInstanceOf[MergeByTempTable]
+
+ val tmp = upsert.createTempTableName()
+ assert(tmp.startsWith("##"))
+ }
+
+ test("MergeByTempTable: Create primary index") {
+ val sql = testMergeByTempTableDialect.getCreatePrimaryIndex("test", Array("id", "ts"))
+ assert(sql === """ALTER TABLE test ADD PRIMARY KEY ("id", "ts")""")
+ }
+
+ test("MergeByTempTable: Create primary index - MsSqlServer") {
+ val msSqlServerDialect = JdbcDialects.get("jdbc:sqlserver")
+ assert(msSqlServerDialect.isInstanceOf[MergeByTempTable])
+ val upsert = msSqlServerDialect.asInstanceOf[MergeByTempTable]
+
+ val sql = upsert.getCreatePrimaryIndex("test", Array("id", "ts"))
+ assert(sql === """ALTER TABLE test ADD PRIMARY KEY CLUSTERED ("id", "ts")""")
+ }
+
+ test("MergeByTempTable: MERGE table into table") {
+ val columns = Array(
+ StructField("id", LongType),
+ StructField("ts", TimestampType),
+ StructField("v1", StringType),
+ StructField("v2", IntegerType)
+ )
+ val keyColumns = Array("id", "ts")
+ val sql = testMergeByTempTableDialect.getMergeQuery(
+ "source", "destination", columns, keyColumns)
+ assert(sql ===
+ """
+ |MERGE INTO destination AS dst
+ | USING source AS src
+ | ON (dst."id" = src."id" AND dst."ts" = src."ts")
+ | WHEN MATCHED THEN
+ | UPDATE SET "v1" = src."v1", "v2" = src."v2"
+ | WHEN NOT MATCHED THEN
+ | INSERT ("id", "ts", "v1", "v2") VALUES (src."id", src."ts", src."v1", src."v2");
+ |""".stripMargin)
+ }
+
test("SPARK 12941: The data type mapping for StringType to Oracle") {
val oracleDialect = JdbcDialects.get("jdbc:oracle://127.0.0.1/db")
assert(oracleDialect.getJDBCType(StringType).
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 535b8257ad94a..709c99f1a6f94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
+import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -207,6 +208,32 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
}
}
+ test("Upsert") {
+ JdbcDialects.registerDialect(H2Dialect())
+ val table = "upsert"
+ spark
+ .range(10)
+ .select(col("id"), col("id").as("val"))
+ .write
+ .jdbc(url, table, new Properties())
+ spark
+ .range(5, 15, 1, 10)
+ .withColumn("val", lit(-1))
+ .write
+ .options(Map("upsert" -> "true", "upsertKeyColumns" -> "id"))
+ .mode(SaveMode.Append)
+ .jdbc(url, table, new Properties())
+ val result = spark.read
+ .jdbc(url, table, new Properties())
+ .select((col("val") === -1).as("updated"))
+ .groupBy(col("updated"))
+ .count()
+ .sort(col("updated"))
+ .collect()
+ // ids 0..4 keep their original val (5 rows); ids 5..9 are updated and ids 10..14 are inserted, so 10 rows have val == -1
+ assert(result === Seq(Row(false, 5), Row(true, 10)))
+ }
+
test("Truncate") {
JdbcDialects.unregisterDialect(H2Dialect())
try {