diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index 8bceb9506e850..5d98d17097a35 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -37,7 +37,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
-class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
+class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new DatabaseOnDocker {
override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME",
"mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04")
@@ -150,8 +150,18 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
"""
|INSERT INTO bits VALUES (1, 2, 1)
""".stripMargin).executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE upsert (id INT, ts DATETIME, v1 FLOAT, v2 FLOAT, " +
+ "CONSTRAINT pk_upsert PRIMARY KEY (id, ts))").executeUpdate()
+ conn.prepareStatement("INSERT INTO upsert VALUES " +
+ "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " +
+ "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " +
+ "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " +
+ "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate()
}
+ override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)"
+
test("Basic test") {
val df = spark.read.jdbc(jdbcUrl, "tbl", new Properties)
val rows = df.collect()
@@ -437,4 +447,5 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
.load()
assert(df.collect().toSet === expectedResult)
}
+
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
index 684cec37c1703..97bfeafcb9e7c 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -39,7 +39,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
-class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
+class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new MySQLDatabaseOnDocker
override def dataPreparation(conn: Connection): Unit = {
@@ -97,8 +97,18 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
conn.prepareStatement("CREATE TABLE TBL_GEOMETRY (col0 GEOMETRY)").executeUpdate()
conn.prepareStatement("INSERT INTO TBL_GEOMETRY VALUES (ST_GeomFromText('POINT(0 0)'))")
.executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE upsert (id INTEGER, ts TIMESTAMP, v1 DOUBLE, v2 DOUBLE, " +
+ "PRIMARY KEY pk (id, ts))").executeUpdate()
+ conn.prepareStatement("INSERT INTO upsert VALUES " +
+ "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " +
+ "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " +
+ "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " +
+ "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate()
}
+ override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)"
+
def testConnection(): Unit = {
Using.resource(getConnection()) { conn =>
assert(conn.getClass.getName === "com.mysql.cj.jdbc.ConnectionImpl")
@@ -285,7 +295,6 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(sql("select x, y from queryOption").collect().toSet == expectedResult)
}
-
test("SPARK-47478: all boolean synonyms read-write roundtrip") {
val df = sqlContext.read.jdbc(jdbcUrl, "bools", new Properties)
checkAnswer(df, Row(true, true, true))
@@ -358,6 +367,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
val df = spark.read.jdbc(jdbcUrl, "smallint_round_trip", new Properties)
assert(df.schema.fields.head.dataType === ShortType)
}
+
}
@@ -375,7 +385,7 @@ class MySQLOverMariaConnectorIntegrationSuite extends MySQLIntegrationSuite {
override val db = new MySQLDatabaseOnDocker {
override def getJdbcUrl(ip: String, port: Int): String =
s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true" +
- s"&useSSL=false"
+ s"&useSSL=false&allowMultiQueries=true"
}
override def testConnection(): Unit = {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 8c0a7c0a809f0..1738ff56f896e 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
-class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
+class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new DatabaseOnDocker {
override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine")
override val env = Map(
@@ -190,8 +190,18 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
conn.prepareStatement("CREATE DOMAIN myint AS integer CHECK (VALUE > 0)").executeUpdate()
conn.prepareStatement("CREATE TABLE domain_table (c1 myint)").executeUpdate()
conn.prepareStatement("INSERT INTO domain_table VALUES (1)").executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE upsert (id integer, ts timestamp, v1 double precision, " +
+ "v2 double precision, CONSTRAINT pk PRIMARY KEY (id, ts))").executeUpdate()
+ conn.prepareStatement("INSERT INTO upsert VALUES " +
+ "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " +
+ "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " +
+ "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " +
+ "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate()
}
+ override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)"
+
test("Type mapping for various types") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
val rows = df.collect().sortBy(_.toString())
@@ -321,6 +331,16 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
.collect()
rows.head.toSeq.tail.foreach(c => assert(c.isInstanceOf[java.sql.Timestamp]))
}
+
+ test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE " +
+ "should be recognized") {
+ // When using JDBC to read the columns of TIMESTAMP with TIME ZONE and TIME with TIME ZONE
+ // the actual types are java.sql.Types.TIMESTAMP and java.sql.Types.TIME
+ val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties)
+ val rows = dfRead.collect()
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types(1).equals("class java.sql.Timestamp"))
+ assert(types(2).equals("class java.sql.Timestamp"))
}
test("SPARK-22291: Conversion error when transforming array types of " +
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala
new file mode 100644
index 0000000000000..6b9cbfb3bd4cf
--- /dev/null
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Timestamp
+import java.util.Properties
+
+import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.functions.{col, lit, rand, when}
+
+trait UpsertTests {
+ self: DockerJDBCIntegrationSuite =>
+
+ import testImplicits._
+
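+  /**
+   * Dialect-specific value for the createTableOptions option, used by the tests below to add
+   * a primary key on (id, ts) to the tables they create.
+   */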
+ def createTableOption: String
+ def upsertTestOptions: Map[String, String] = Map("createTableOptions" -> createTableOption)
+
+ test(s"Upsert existing table") { doTestUpsert(tableExists = true) }
+ test(s"Upsert non-existing table") { doTestUpsert(tableExists = false) }
+
+ Seq(
+ Seq("ts", "id", "v1", "v2"),
+ Seq("ts", "v1", "id", "v2"),
+ Seq("ts", "v1", "v2", "id"),
+ Seq("v2", "v1", "ts", "id")
+ ).foreach { columns =>
+ test(s"Upsert with varying column order - ${columns.mkString(",")}") {
+ doTestUpsert(tableExists = true, Some(columns))
+ }
+ }
+
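+  /**
+   * Upserts four rows into the pre-populated `upsert` table (or a freshly created
+   * `new_upsert_table`) and verifies that existing rows are updated, untouched rows are
+   * preserved, and new rows are inserted.
+   */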
+ def doTestUpsert(tableExists: Boolean, project: Option[Seq[String]] = None): Unit = {
+ val df = Seq(
+ (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), // row unchanged
+ (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), // updates v1
+ (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), // updates v1 and v2
+ (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) // inserts new row
+ ).toDF("id", "ts", "v1", "v2").repartition(1) // .repartition(10)
+
+ val table = if (tableExists) "upsert" else "new_upsert_table"
+ val options = upsertTestOptions ++ Map(
+ "numPartitions" -> "10",
+ "upsert" -> "true",
+ "upsertKeyColumns" -> "id, ts"
+ )
+ project.map(columns => df.select(columns.map(col): _*)).getOrElse(df)
+ .write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties)
+
+ val actual = spark.read.jdbc(jdbcUrl, table, new Properties).collect().toSet
+ val existing = if (tableExists) {
+ Set((1, Timestamp.valueOf("1996-01-01 01:23:45"), 1.234, 1.234567))
+ } else {
+ Set.empty
+ }
+ val upsertedRows = Set(
+ (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568),
+ (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678),
+ (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680),
+ (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789)
+ )
+ val expected = (existing ++ upsertedRows).map { case (id, ts, v1, v2) =>
+ Row(Integer.valueOf(id), ts, v1.doubleValue(), v2.doubleValue())
+ }
+ assert(actual === expected)
+ }
+
+ test(s"Upsert concurrency") {
+ // create a table with 100k rows
+ val init =
+ spark.range(100000)
+ .withColumn("ts", lit(Timestamp.valueOf("2023-06-07 12:34:56")))
+ .withColumn("v", rand())
+
+ // upsert 100 batches of 100 rows each
+ // run in 32 tasks
+ val patch =
+ spark.range(100)
+ .join(spark.range(100).select(($"id" * 1000).as("offset")))
+ .repartition(32)
+ .select(
+ ($"id" + $"offset").as("id"),
+ lit(Timestamp.valueOf("2023-06-07 12:34:56")).as("ts"),
+ lit(-1.0).as("v")
+ )
+
+ spark.sparkContext.setJobDescription("init")
+ init
+ .write
+ .mode(SaveMode.Overwrite)
+ .option("createTableOptions", createTableOption)
+ .jdbc(jdbcUrl, "new_upsert_table", new Properties)
+
+ spark.sparkContext.setJobDescription("patch")
+ patch
+ .write
+ .mode(SaveMode.Append)
+ .option("upsert", value = true)
+ .option("upsertKeyColumns", "id, ts")
+ .options(upsertTestOptions)
+ .jdbc(jdbcUrl, "new_upsert_table", new Properties)
+
+ // check result table has 100*100 updated rows
+ val result = spark.read.jdbc(jdbcUrl, "new_upsert_table", new Properties)
+ .select($"id", when($"v" === -1.0, true).otherwise(false).as("patched"))
+ .groupBy($"patched")
+ .count()
+ .sort($"patched")
+ .as[(Boolean, Long)]
+ .collect()
+ assert(result === Seq((false, 90000), (true, 10000)))
+ }
+
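+  // TODO: the remaining scenarios are placeholders and still need test implementations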
+ test("Upsert with columns that require quotes") {}
+ test("Upsert with table name that requires quotes") {}
+ test("Upsert null values") {}
+
+ test("Write with unspecified mode with upsert") {}
+ test("Write with overwrite mode with upsert") {}
+ test("Write with error-if-exists mode with upsert") {}
+ test("Write with ignore mode with upsert") {}
+}
diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index 637efc24113ef..0a0d320ce9965 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -261,6 +261,19 @@ logging into the data sources.
    <td>write</td>
  </tr>

+  <tr>
+    <td><code>upsert</code>, <code>upsertKeyColumns</code></td>
+    <td>
+      These are JDBC writer related options. They describe how to use the UPSERT feature with
+      different JDBC dialects. The <code>upsert</code> option applies only with <code>SaveMode.Append</code>;
+      set it to <code>true</code> to enable upsert append mode. The database is queried for the primary
+      index to detect the upsert key columns that identify the rows to update; they can also be defined
+      explicitly via <code>upsertKeyColumns</code> as a comma-separated list of column names. Be aware
+      that if the input data set contains duplicate rows, the upsert operation is non-deterministic;
+      see <a href="https://en.wikipedia.org/wiki/Merge_(SQL)">MERGE (SQL)</a>.
+    </td>
+    <td>write</td>
+  </tr>
+
  <tr>
    <td><code>customSchema</code></td>
    <td>(none)</td>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 24c85b76437cc..f48d3f7c75ee9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -1471,36 +1471,40 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map("changes" -> changes.toString()))
}
- private def tableDoesNotSupportError(cmd: String, table: Table): Throwable = {
+ private def tableDoesNotSupportError(cmd: String, tableName: String): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1121",
messageParameters = Map(
"cmd" -> cmd,
- "table" -> table.name))
+ "table" -> tableName))
}
def tableDoesNotSupportReadsError(table: Table): Throwable = {
- tableDoesNotSupportError("reads", table)
+ tableDoesNotSupportError("reads", table.name)
}
def tableDoesNotSupportWritesError(table: Table): Throwable = {
- tableDoesNotSupportError("writes", table)
+ tableDoesNotSupportError("writes", table.name)
}
def tableDoesNotSupportDeletesError(table: Table): Throwable = {
- tableDoesNotSupportError("deletes", table)
+ tableDoesNotSupportError("deletes", table.name)
}
def tableDoesNotSupportTruncatesError(table: Table): Throwable = {
- tableDoesNotSupportError("truncates", table)
+ tableDoesNotSupportError("truncates", table.name)
+ }
+
+ def tableDoesNotSupportUpsertsError(tableName: String): Throwable = {
+ tableDoesNotSupportError("upserts", tableName)
}
def tableDoesNotSupportPartitionManagementError(table: Table): Throwable = {
- tableDoesNotSupportError("partition management", table)
+ tableDoesNotSupportError("partition management", table.name)
}
def tableDoesNotSupportAtomicPartitionManagementError(table: Table): Throwable = {
- tableDoesNotSupportError("atomic partition management", table)
+ tableDoesNotSupportError("atomic partition management", table.name)
}
def tableIsNotRowLevelOperationTableError(table: Table): Throwable = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 43db0c6eef114..0b6a02e161244 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -162,6 +162,11 @@ class JDBCOptions(
val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean
val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean)
+  // whether to upsert the table in the JDBC database
+  val isUpsert = parameters.getOrElse(JDBC_UPSERT, "false").toBoolean
+  // the columns used to identify existing rows to update in upsert mode
+  val upsertKeyColumns = parameters.getOrElse(JDBC_UPSERT_KEY_COLUMNS, "").split(",").map(_.trim)
+
// the create table option , which can be table_options or partition_options.
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
// TODO: to reuse the existing partition parameters for those partition specific options
@@ -296,6 +301,8 @@ object JDBCOptions {
val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate")
+ val JDBC_UPSERT = newOption("upsert")
+ val JDBC_UPSERT_KEY_COLUMNS = newOption("upsertKeyColumns")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index d9be1a1e3f674..0808d16749ff8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -50,6 +50,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
val dialect = JdbcDialects.get(options.url)
val conn = dialect.createConnectionFactory(options)(-1)
try {
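+      // upsert only takes effect with SaveMode.Append; for all other save modes the flag stays false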
+ val upsert = mode == SaveMode.Append && options.isUpsert
val tableExists = JdbcUtils.tableExists(conn, options)
if (tableExists) {
mode match {
@@ -58,17 +59,20 @@ class JdbcRelationProvider extends CreatableRelationProvider
// In this case, we should truncate table and then load.
truncateTable(conn, options)
val tableSchema = JdbcUtils.getSchemaOption(conn, options)
- saveTable(df, tableSchema, isCaseSensitive, options)
+ saveTable(df, tableSchema, isCaseSensitive, upsert, options)
} else {
// Otherwise, do not truncate the table, instead drop and recreate it
dropTable(conn, options.table, options)
createTable(conn, options.table, df.schema, isCaseSensitive, options)
- saveTable(df, Some(df.schema), isCaseSensitive, options)
+ saveTable(df, Some(df.schema), isCaseSensitive, upsert, options)
}
case SaveMode.Append =>
+ if (options.isUpsert && !dialect.supportsUpsert()) {
+ throw QueryCompilationErrors.tableDoesNotSupportUpsertsError(options.table)
+ }
val tableSchema = JdbcUtils.getSchemaOption(conn, options)
- saveTable(df, tableSchema, isCaseSensitive, options)
+ saveTable(df, tableSchema, isCaseSensitive, upsert, options)
case SaveMode.ErrorIfExists =>
throw QueryCompilationErrors.tableOrViewAlreadyExistsError(options.table)
@@ -80,7 +84,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
}
} else {
createTable(conn, options.table, df.schema, isCaseSensitive, options)
- saveTable(df, Some(df.schema), isCaseSensitive, options)
+ saveTable(df, Some(df.schema), isCaseSensitive, upsert, options)
}
} finally {
conn.close()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index c541ec16fc822..ab278c17c977b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -109,31 +109,46 @@ object JdbcUtils extends Logging with SQLConfHelper {
JdbcDialects.get(url).isCascadingTruncateTable()
}
- /**
- * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
- */
- def getInsertStatement(
- table: String,
+ protected def getInsertColumns(
rddSchema: StructType,
tableSchema: Option[StructType],
isCaseSensitive: Boolean,
- dialect: JdbcDialect): String = {
- val columns = if (tableSchema.isEmpty) {
+ dialect: JdbcDialect): Array[StructField] = {
+ if (tableSchema.isEmpty) {
rddSchema.fields
} else {
- // The generated insert statement needs to follow rddSchema's column sequence and
- // tableSchema's column names. When appending data into some case-sensitive DBMSs like
- // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
- // RDD column names for user convenience.
rddSchema.fields.map { col =>
tableSchema.get.find(f => conf.resolver(f.name, col.name)).getOrElse {
throw QueryCompilationErrors.columnNotFoundInSchemaError(col, tableSchema)
}
}
}
+ }
+
+ /**
+ * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
+ */
+ def getInsertStatement(
+ table: String,
+ rddSchema: StructType,
+ tableSchema: Option[StructType],
+ isCaseSensitive: Boolean,
+ dialect: JdbcDialect): String = {
+ val columns = getInsertColumns(rddSchema, tableSchema, isCaseSensitive, dialect)
dialect.insertIntoTable(table, columns)
}
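+
+  /**
+   * Returns an upsert SQL statement that inserts a row into the target table via JDBC conn,
+   * or updates the existing row identified by the upsert key columns in `options`.
+   */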
+ def getUpsertStatement(
+ table: String,
+ rddSchema: StructType,
+ tableSchema: Option[StructType],
+ isCaseSensitive: Boolean,
+ dialect: JdbcDialect,
+ options: JDBCOptions): String = {
+ val columns = getInsertColumns(rddSchema, tableSchema, isCaseSensitive, dialect)
+ dialect.getUpsertStatement(table, columns, isCaseSensitive, options)
+ }
+
/**
* Retrieve standard jdbc types.
*
@@ -958,6 +973,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
df: DataFrame,
tableSchema: Option[StructType],
isCaseSensitive: Boolean,
+ upsert: Boolean,
options: JdbcOptionsInWrite): Unit = {
val url = options.url
val table = options.table
@@ -966,7 +982,12 @@ object JdbcUtils extends Logging with SQLConfHelper {
val batchSize = options.batchSize
val isolationLevel = options.isolationLevel
- val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
+ val insertStmt = if (upsert) {
+ getUpsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect, options)
+ } else {
+ getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
+ }
+
val repartitionedDF = options.numPartitions match {
case Some(n) if n <= 0 => throw QueryExecutionErrors.invalidJdbcNumPartitionsError(
n, JDBCOptions.JDBC_NUM_PARTITIONS)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala
index 7449f66ee020f..1b0bf0f78815b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala
@@ -42,7 +42,8 @@ case class JDBCWriteBuilder(schema: StructType, options: JdbcOptionsInWrite) ext
val conn = dialect.createConnectionFactory(options)(-1)
JdbcUtils.truncateTable(conn, options)
}
- JdbcUtils.saveTable(data, Some(schema), SQLConf.get.caseSensitiveAnalysis, options)
+ JdbcUtils.saveTable(
+ data, Some(schema), SQLConf.get.caseSensitiveAnalysis, upsert = false, options)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 5f69d18cad756..32a006a343980 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -335,6 +335,17 @@ abstract class JdbcDialect extends Serializable with Logging {
s"TRUNCATE TABLE $table"
}
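+
+  /**
+   * Returns true if the dialect can generate an upsert statement via `getUpsertStatement`.
+   */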
+ @Since("3.5.0")
+ def supportsUpsert(): Boolean = false
+
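+  /**
+   * Returns a SQL statement that updates the existing row identified by the upsert key columns
+   * given in `options`, or inserts the row if it does not exist yet (for example, the PostgreSQL
+   * dialect emits `INSERT ... ON CONFLICT ... DO UPDATE`).
+   */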
+ @Since("3.5.0")
+ def getUpsertStatement(
+ tableName: String,
+ columns: Array[StructField],
+ isCaseSensitive: Boolean,
+ options: JDBCOptions): String =
+ throw new UnsupportedOperationException("upserts are not supported")
+
/**
* Override connection specific properties to run before a select is made. This is in place to
* allow dialects that need special treatment to optimize behavior.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 862e99adc3b0d..d1c73c2100ab2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, SortDirection}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.MsSqlServerDialect.{GEOGRAPHY, GEOMETRY}
import org.apache.spark.sql.types._
@@ -139,6 +139,54 @@ private case class MsSqlServerDialect() extends JdbcDialect {
case _ => None
}
+ override def supportsUpsert(): Boolean = true
+
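+  // Emulates an upsert without MERGE: one T-SQL variable is declared per JDBC parameter so each
+  // value can be referenced in several clauses, the row is inserted only if no row with the same
+  // key exists, and otherwise the existing row is updated (the UPDLOCK, SERIALIZABLE hints guard
+  // against concurrent writers).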
+ override def getUpsertStatement(
+ tableName: String,
+ columns: Array[StructField],
+ isCaseSensitive: Boolean,
+ options: JDBCOptions): String = {
+ val insertColumns = columns.map(_.name).map(quoteIdentifier)
+ val inputs = columns
+ .map(_.dataType)
+ .map(t => JdbcUtils.getJdbcType(t, this).databaseTypeDefinition)
+ .zipWithIndex.map {
+ case (t, idx) => s"DECLARE @param$idx $t; SET @param$idx = ?;"
+ }.mkString("\n")
+ val values = columns.indices.map(i => s"@param$i").mkString(", ")
+ val keyColumns = columns.zipWithIndex.filter {
+ case (col, _) => options.upsertKeyColumns.contains(col.name)
+ }
+ val updateColumns = columns.zipWithIndex.filterNot {
+ case (col, _) => options.upsertKeyColumns.contains(col.name)
+ }
+ val whereClause = keyColumns.map {
+ case (key, idx) => s"${quoteIdentifier(key.name)} = @param$idx"
+ }.mkString(" AND ")
+ val updateClause = updateColumns.map {
+ case (col, idx) => s"${quoteIdentifier(col.name)} = @param$idx"
+ }.mkString(", ")
+
+ s"""
+ |$inputs
+ |
+ |INSERT $tableName (${insertColumns.mkString(", ")})
+ |SELECT $values
+ |WHERE NOT EXISTS (
+ | SELECT 1
+ | FROM $tableName WITH (UPDLOCK, SERIALIZABLE)
+ | WHERE $whereClause
+ |)
+ |
+ |IF (@@ROWCOUNT = 0)
+ |BEGIN
+ | UPDATE TOP (1) $tableName
+ | SET $updateClause
+ | WHERE $whereClause
+ |END
+ |""".stripMargin
+ }
+
override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
// scalastyle:off line.size.limit
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index d98fcdfd0b23f..f954eda39e804 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -170,6 +170,31 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper {
schemaBuilder.result()
}
+ override def getTableExistsQuery(table: String): String = {
+ s"SELECT 1 FROM $table LIMIT 1"
+ }
+
+ override def supportsUpsert(): Boolean = true
+
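+  // Upsert via INSERT ... ON DUPLICATE KEY UPDATE. MySQL derives the conflict target from the
+  // table's PRIMARY KEY or UNIQUE indexes, so upsertKeyColumns is only used here to exclude the
+  // key columns from the UPDATE clause.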
+ override def getUpsertStatement(
+ tableName: String,
+ columns: Array[StructField],
+ isCaseSensitive: Boolean,
+ options: JDBCOptions): String = {
+ val insertColumns = columns.map(_.name).map(quoteIdentifier)
+ val placeholders = columns.map(_ => "?").mkString(",")
+ val upsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier)
+ val updateColumns = insertColumns.filterNot(upsertKeyColumns.contains)
+ val updateClause =
+ updateColumns.map(x => s"$x = VALUES($x)").mkString(", ")
+
+ s"""
+ |INSERT INTO $tableName (${insertColumns.mkString(", ")})
+ |VALUES ( $placeholders )
+ |ON DUPLICATE KEY UPDATE $updateClause
+ |""".stripMargin
+ }
+
override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
// See https://dev.mysql.com/doc/refman/8.0/en/alter-table.html
@@ -307,7 +332,7 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper {
}
} catch {
case _: Exception =>
- logWarning("Cannot retrieved index info.")
+ logWarning("Cannot retrieve index info.")
}
indexMap.values.toArray
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index e4cd79e3f53c7..b2a359a0a2ac0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -158,6 +158,32 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper {
case _ => None
}
+ override def getTableExistsQuery(table: String): String = {
+ s"SELECT 1 FROM $table LIMIT 1"
+ }
+
+ override def supportsUpsert(): Boolean = true
+
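+  // Upsert via INSERT ... ON CONFLICT ... DO UPDATE. The upsert key columns form the conflict
+  // target, so they must be backed by a unique index or constraint on the target table.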
+ override def getUpsertStatement(
+ tableName: String,
+ columns: Array[StructField],
+ isCaseSensitive: Boolean,
+ options: JDBCOptions): String = {
+ val insertColumns = columns.map(_.name).map(quoteIdentifier)
+ val placeholders = columns.map(_ => "?").mkString(",")
+ val upsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier)
+ val updateColumns = insertColumns.filterNot(upsertKeyColumns.contains)
+ val updateClause =
+ updateColumns.map(x => s"$x = EXCLUDED.$x").mkString(", ")
+
+ s"""
+ |INSERT INTO $tableName (${insertColumns.mkString(", ")})
+ |VALUES ( $placeholders )
+ |ON CONFLICT (${upsertKeyColumns.mkString(", ")})
+ |DO UPDATE SET $updateClause
+ |""".stripMargin
+ }
+
override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 34c554f7d37e1..2bb705c75c987 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.{DataSourceScanExec, ExtendedMode, Project
import org.apache.spark.sql.execution.command.{ExplainCommand, ShowCreateTableCommand}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.JDBC_UPSERT_KEY_COLUMNS
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.functions.{lit, percentile_approx}
import org.apache.spark.sql.internal.SQLConf
@@ -1156,6 +1157,68 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
assert(db2.getTruncateQuery(table, Some(true)) == db2Query)
}
+ Seq(
+ (JdbcDialects.get("jdbc:mysql://127.0.0.1/db"),
+ """
+ |INSERT INTO table (`id`, `time`, `value`, `comment`)
+ |VALUES ( ?,?,?,? )
+ |ON DUPLICATE KEY UPDATE `value` = VALUES(`value`), `comment` = VALUES(`comment`)
+ |""".stripMargin),
+ (JdbcDialects.get("jdbc:postgresql://127.0.0.1/db"),
+ """
+ |INSERT INTO table ("id", "time", "value", "comment")
+ |VALUES ( ?,?,?,? )
+ |ON CONFLICT ("id", "time")
+ |DO UPDATE SET "value" = EXCLUDED."value", "comment" = EXCLUDED."comment"
+ |""".stripMargin),
+ (JdbcDialects.get("jdbc:sqlserver://localhost/db"),
+ """
+ |DECLARE @param0 BIGINT; SET @param0 = ?;
+ |DECLARE @param1 DATETIME; SET @param1 = ?;
+ |DECLARE @param2 DOUBLE PRECISION; SET @param2 = ?;
+ |DECLARE @param3 NVARCHAR(MAX); SET @param3 = ?;
+ |
+ |INSERT table ("id", "time", "value", "comment")
+ |SELECT @param0, @param1, @param2, @param3
+ |WHERE NOT EXISTS (
+ | SELECT 1
+ | FROM table WITH (UPDLOCK, SERIALIZABLE)
+ | WHERE "id" = @param0 AND "time" = @param1
+ |)
+ |
+ |IF (@@ROWCOUNT = 0)
+ |BEGIN
+ | UPDATE TOP (1) table
+ | SET "value" = @param2, "comment" = @param3
+ | WHERE "id" = @param0 AND "time" = @param1
+ |END
+ |""".stripMargin)
+ ).foreach { case (dialect, expected) =>
+ test(s"upsert table query by dialect - ${dialect.getClass.getSimpleName.stripSuffix("$")}") {
+ assert(dialect.supportsUpsert() === true)
+
+ val options = {
+ new JDBCOptions(Map(
+ JDBC_UPSERT_KEY_COLUMNS -> "id, time",
+ JDBCOptions.JDBC_URL -> url,
+ JDBCOptions.JDBC_TABLE_NAME -> "table"
+ ))
+ }
+
+ val table = "table"
+ val columns = Array(
+ StructField("id", LongType),
+ StructField("time", TimestampType),
+ StructField("value", DoubleType),
+ StructField("comment", StringType)
+ )
+ val isCaseSensitive = false
+ val stmt = dialect.getUpsertStatement(table, columns, isCaseSensitive, options)
+
+ assert(stmt === expected)
+ }
+ }
+
test("Test DataFrame.where for Date and Timestamp") {
// Regression test for bug SPARK-11788
val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543");