diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index 8bceb9506e850..5d98d17097a35 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { +class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests { override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", "mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") @@ -150,8 +150,18 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { """ |INSERT INTO bits VALUES (1, 2, 1) """.stripMargin).executeUpdate() + + conn.prepareStatement("CREATE TABLE upsert (id INT, ts DATETIME, v1 FLOAT, v2 FLOAT, " + + "CONSTRAINT pk_upsert PRIMARY KEY (id, ts))").executeUpdate() + conn.prepareStatement("INSERT INTO upsert VALUES " + + "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " + + "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " + + "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " + + "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate() } + override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)" + test("Basic test") { val df = spark.read.jdbc(jdbcUrl, "tbl", new Properties) val rows = df.collect() @@ -437,4 +447,5 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { .load() assert(df.collect().toSet === expectedResult) } + } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 684cec37c1703..97bfeafcb9e7c 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { +class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests { override val db = new MySQLDatabaseOnDocker override def dataPreparation(conn: Connection): Unit = { @@ -97,8 +97,18 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { conn.prepareStatement("CREATE TABLE TBL_GEOMETRY (col0 GEOMETRY)").executeUpdate() conn.prepareStatement("INSERT INTO TBL_GEOMETRY VALUES (ST_GeomFromText('POINT(0 0)'))") .executeUpdate() + + conn.prepareStatement("CREATE TABLE upsert (id INTEGER, ts TIMESTAMP, v1 DOUBLE, v2 DOUBLE, " + + "PRIMARY KEY pk (id, ts))").executeUpdate() + conn.prepareStatement("INSERT INTO upsert VALUES " + + "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " + + "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " + + "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " + + "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate() } + override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)" + def testConnection(): Unit = { Using.resource(getConnection()) { conn => 
assert(conn.getClass.getName === "com.mysql.cj.jdbc.ConnectionImpl") @@ -285,7 +295,6 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { assert(sql("select x, y from queryOption").collect().toSet == expectedResult) } - test("SPARK-47478: all boolean synonyms read-write roundtrip") { val df = sqlContext.read.jdbc(jdbcUrl, "bools", new Properties) checkAnswer(df, Row(true, true, true)) @@ -358,6 +367,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { val df = spark.read.jdbc(jdbcUrl, "smallint_round_trip", new Properties) assert(df.schema.fields.head.dataType === ShortType) } + } @@ -375,7 +385,7 @@ class MySQLOverMariaConnectorIntegrationSuite extends MySQLIntegrationSuite { override val db = new MySQLDatabaseOnDocker { override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true" + - s"&useSSL=false" + s"&useSSL=false&allowMultiQueries=true" } override def testConnection(): Unit = { diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 8c0a7c0a809f0..1738ff56f896e 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { +class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests { override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") override val env = Map( @@ -190,8 +190,18 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { conn.prepareStatement("CREATE DOMAIN myint AS integer CHECK (VALUE > 0)").executeUpdate() conn.prepareStatement("CREATE TABLE domain_table (c1 myint)").executeUpdate() conn.prepareStatement("INSERT INTO domain_table VALUES (1)").executeUpdate() + + conn.prepareStatement("CREATE TABLE upsert (id integer, ts timestamp, v1 double precision, " + + "v2 double precision, CONSTRAINT pk PRIMARY KEY (id, ts))").executeUpdate() + conn.prepareStatement("INSERT INTO upsert VALUES " + + "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " + + "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " + + "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " + + "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate() } + override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)" + test("Type mapping for various types") { val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) val rows = df.collect().sortBy(_.toString()) @@ -321,6 +331,16 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { .collect() rows.head.toSeq.tail.foreach(c => assert(c.isInstanceOf[java.sql.Timestamp])) } + + test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE " + + "should be recognized") { + // When using JDBC to read the columns of TIMESTAMP with TIME ZONE and TIME with TIME ZONE + // the actual types are java.sql.Types.TIMESTAMP and java.sql.Types.TIME + val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) + val rows = dfRead.collect() + val types = rows(0).toSeq.map(x => 
x.getClass.toString) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) } test("SPARK-22291: Conversion error when transforming array types of " + diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala new file mode 100644 index 0000000000000..6b9cbfb3bd4cf --- /dev/null +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Timestamp +import java.util.Properties + +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.functions.{col, lit, rand, when} + +trait UpsertTests { + self: DockerJDBCIntegrationSuite => + + import testImplicits._ + + def createTableOption: String + def upsertTestOptions: Map[String, String] = Map("createTableOptions" -> createTableOption) + + test(s"Upsert existing table") { doTestUpsert(tableExists = true) } + test(s"Upsert non-existing table") { doTestUpsert(tableExists = false) } + + Seq( + Seq("ts", "id", "v1", "v2"), + Seq("ts", "v1", "id", "v2"), + Seq("ts", "v1", "v2", "id"), + Seq("v2", "v1", "ts", "id") + ).foreach { columns => + test(s"Upsert with varying column order - ${columns.mkString(",")}") { + doTestUpsert(tableExists = true, Some(columns)) + } + } + + def doTestUpsert(tableExists: Boolean, project: Option[Seq[String]] = None): Unit = { + val df = Seq( + (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), // row unchanged + (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), // updates v1 + (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), // updates v1 and v2 + (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) // inserts new row + ).toDF("id", "ts", "v1", "v2").repartition(1) // .repartition(10) + + val table = if (tableExists) "upsert" else "new_upsert_table" + val options = upsertTestOptions ++ Map( + "numPartitions" -> "10", + "upsert" -> "true", + "upsertKeyColumns" -> "id, ts" + ) + project.map(columns => df.select(columns.map(col): _*)).getOrElse(df) + .write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties) + + val actual = spark.read.jdbc(jdbcUrl, table, new Properties).collect().toSet + val existing = if (tableExists) { + Set((1, Timestamp.valueOf("1996-01-01 01:23:45"), 1.234, 1.234567)) + } else { + Set.empty + } + val upsertedRows = Set( + (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), + (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), + (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 
2.345680), + (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) + ) + val expected = (existing ++ upsertedRows).map { case (id, ts, v1, v2) => + Row(Integer.valueOf(id), ts, v1.doubleValue(), v2.doubleValue()) + } + assert(actual === expected) + } + + test(s"Upsert concurrency") { + // create a table with 100k rows + val init = + spark.range(100000) + .withColumn("ts", lit(Timestamp.valueOf("2023-06-07 12:34:56"))) + .withColumn("v", rand()) + + // upsert 100 batches of 100 rows each + // run in 32 tasks + val patch = + spark.range(100) + .join(spark.range(100).select(($"id" * 1000).as("offset"))) + .repartition(32) + .select( + ($"id" + $"offset").as("id"), + lit(Timestamp.valueOf("2023-06-07 12:34:56")).as("ts"), + lit(-1.0).as("v") + ) + + spark.sparkContext.setJobDescription("init") + init + .write + .mode(SaveMode.Overwrite) + .option("createTableOptions", createTableOption) + .jdbc(jdbcUrl, "new_upsert_table", new Properties) + + spark.sparkContext.setJobDescription("patch") + patch + .write + .mode(SaveMode.Append) + .option("upsert", value = true) + .option("upsertKeyColumns", "id, ts") + .options(upsertTestOptions) + .jdbc(jdbcUrl, "new_upsert_table", new Properties) + + // check result table has 100*100 updated rows + val result = spark.read.jdbc(jdbcUrl, "new_upsert_table", new Properties) + .select($"id", when($"v" === -1.0, true).otherwise(false).as("patched")) + .groupBy($"patched") + .count() + .sort($"patched") + .as[(Boolean, Long)] + .collect() + assert(result === Seq((false, 90000), (true, 10000))) + } + + test("Upsert with columns that require quotes") {} + test("Upsert with table name that requires quotes") {} + test("Upsert null values") {} + + test("Write with unspecified mode with upsert") {} + test("Write with overwrite mode with upsert") {} + test("Write with error-if-exists mode with upsert") {} + test("Write with ignore mode with upsert") {} +} diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md index 637efc24113ef..0a0d320ce9965 100644 --- a/docs/sql-data-sources-jdbc.md +++ b/docs/sql-data-sources-jdbc.md @@ -261,6 +261,19 @@ logging into the data sources. write + + upsert, upsertKeyColumns + + These options are JDBC writer related options. They describe how to + use the UPSERT feature for different JDBC dialects. The upsert option is applicable only when SaveMode.Append is enabled. + Set upsert to true to enable upsert append mode. The database is queried for the primary index to detect + the upsert key columns that are used to identify rows for update. The upsert key columns can be + defined via the upsertKeyColumns option as a comma-separated list of column names. + Be aware that if the input data set has duplicate rows, the upsert operation is + non-deterministic; see the [upsert (merge)](https://en.wikipedia.org/wiki/Merge_(SQL)) article for details.
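For reference, a minimal usage sketch of the two write options documented above, assuming this change is applied, a running SparkSession, a reachable PostgreSQL instance, and a target table `events` with primary key `(id, ts)`; the URL, credentials, table and column names are illustrative assumptions, not part of this change.

```scala
import java.sql.Timestamp
import java.util.Properties

import org.apache.spark.sql.{SaveMode, SparkSession}

object JdbcUpsertExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("jdbc-upsert-example").getOrCreate()
    import spark.implicits._

    // Hypothetical connection details; replace with a real JDBC URL and credentials.
    val url = "jdbc:postgresql://localhost:5432/db?user=spark&password=secret"

    // Rows to merge into the target table "events" (primary key: id, ts).
    val updates = Seq(
      (1, Timestamp.valueOf("2024-01-01 00:00:00"), 42.0),
      (2, Timestamp.valueOf("2024-01-01 00:00:00"), 43.0)
    ).toDF("id", "ts", "v")

    updates.write
      .mode(SaveMode.Append)                 // upsert is only honoured in Append mode
      .option("upsert", "true")              // use the dialect's upsert statement instead of plain INSERT
      .option("upsertKeyColumns", "id, ts")  // columns that identify the row to update
      .jdbc(url, "events", new Properties)

    spark.stop()
  }
}
```

With SaveMode.Append and upsert enabled, rows whose key already exists are updated and all other rows are inserted, which is the behaviour exercised by doTestUpsert in UpsertTests above.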
+ + + customSchema (none) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 24c85b76437cc..f48d3f7c75ee9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1471,36 +1471,40 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("changes" -> changes.toString())) } - private def tableDoesNotSupportError(cmd: String, table: Table): Throwable = { + private def tableDoesNotSupportError(cmd: String, tableName: String): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1121", messageParameters = Map( "cmd" -> cmd, - "table" -> table.name)) + "table" -> tableName)) } def tableDoesNotSupportReadsError(table: Table): Throwable = { - tableDoesNotSupportError("reads", table) + tableDoesNotSupportError("reads", table.name) } def tableDoesNotSupportWritesError(table: Table): Throwable = { - tableDoesNotSupportError("writes", table) + tableDoesNotSupportError("writes", table.name) } def tableDoesNotSupportDeletesError(table: Table): Throwable = { - tableDoesNotSupportError("deletes", table) + tableDoesNotSupportError("deletes", table.name) } def tableDoesNotSupportTruncatesError(table: Table): Throwable = { - tableDoesNotSupportError("truncates", table) + tableDoesNotSupportError("truncates", table.name) + } + + def tableDoesNotSupportUpsertsError(tableName: String): Throwable = { + tableDoesNotSupportError("upserts", tableName) } def tableDoesNotSupportPartitionManagementError(table: Table): Throwable = { - tableDoesNotSupportError("partition management", table) + tableDoesNotSupportError("partition management", table.name) } def tableDoesNotSupportAtomicPartitionManagementError(table: Table): Throwable = { - tableDoesNotSupportError("atomic partition management", table) + tableDoesNotSupportError("atomic partition management", table.name) } def tableIsNotRowLevelOperationTableError(table: Table): Throwable = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 43db0c6eef114..0b6a02e161244 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -162,6 +162,11 @@ class JDBCOptions( val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean) + // if to upsert the table in the JDBC database + val isUpsert = parameters.getOrElse(JDBC_UPSERT, "false").toBoolean + // the columns used to identify update and insert rows in upsert mode + val upsertKeyColumns = parameters.getOrElse(JDBC_UPSERT_KEY_COLUMNS, "").split(",").map(_.trim) + // the create table option , which can be table_options or partition_options. 
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options @@ -296,6 +301,8 @@ object JDBCOptions { val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate") + val JDBC_UPSERT = newOption("upsert") + val JDBC_UPSERT_KEY_COLUMNS = newOption("upsertKeyColumns") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index d9be1a1e3f674..0808d16749ff8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -50,6 +50,7 @@ class JdbcRelationProvider extends CreatableRelationProvider val dialect = JdbcDialects.get(options.url) val conn = dialect.createConnectionFactory(options)(-1) try { + val upsert = mode == SaveMode.Append && options.isUpsert val tableExists = JdbcUtils.tableExists(conn, options) if (tableExists) { mode match { @@ -58,17 +59,20 @@ class JdbcRelationProvider extends CreatableRelationProvider // In this case, we should truncate table and then load. truncateTable(conn, options) val tableSchema = JdbcUtils.getSchemaOption(conn, options) - saveTable(df, tableSchema, isCaseSensitive, options) + saveTable(df, tableSchema, isCaseSensitive, upsert, options) } else { // Otherwise, do not truncate the table, instead drop and recreate it dropTable(conn, options.table, options) createTable(conn, options.table, df.schema, isCaseSensitive, options) - saveTable(df, Some(df.schema), isCaseSensitive, options) + saveTable(df, Some(df.schema), isCaseSensitive, upsert, options) } case SaveMode.Append => + if (options.isUpsert && !dialect.supportsUpsert()) { + throw QueryCompilationErrors.tableDoesNotSupportUpsertsError(options.table) + } val tableSchema = JdbcUtils.getSchemaOption(conn, options) - saveTable(df, tableSchema, isCaseSensitive, options) + saveTable(df, tableSchema, isCaseSensitive, upsert, options) case SaveMode.ErrorIfExists => throw QueryCompilationErrors.tableOrViewAlreadyExistsError(options.table) @@ -80,7 +84,7 @@ class JdbcRelationProvider extends CreatableRelationProvider } } else { createTable(conn, options.table, df.schema, isCaseSensitive, options) - saveTable(df, Some(df.schema), isCaseSensitive, options) + saveTable(df, Some(df.schema), isCaseSensitive, upsert, options) } } finally { conn.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index c541ec16fc822..ab278c17c977b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -109,31 +109,46 @@ object JdbcUtils extends Logging with SQLConfHelper { JdbcDialects.get(url).isCascadingTruncateTable() } - /** - * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn. 
- */ - def getInsertStatement( - table: String, + protected def getInsertColumns( rddSchema: StructType, tableSchema: Option[StructType], isCaseSensitive: Boolean, - dialect: JdbcDialect): String = { - val columns = if (tableSchema.isEmpty) { + dialect: JdbcDialect): Array[StructField] = { + if (tableSchema.isEmpty) { rddSchema.fields } else { - // The generated insert statement needs to follow rddSchema's column sequence and - // tableSchema's column names. When appending data into some case-sensitive DBMSs like - // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of - // RDD column names for user convenience. rddSchema.fields.map { col => tableSchema.get.find(f => conf.resolver(f.name, col.name)).getOrElse { throw QueryCompilationErrors.columnNotFoundInSchemaError(col, tableSchema) } } } + } + + /** + * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn. + */ + def getInsertStatement( + table: String, + rddSchema: StructType, + tableSchema: Option[StructType], + isCaseSensitive: Boolean, + dialect: JdbcDialect): String = { + val columns = getInsertColumns(rddSchema, tableSchema, isCaseSensitive, dialect) dialect.insertIntoTable(table, columns) } + def getUpsertStatement( + table: String, + rddSchema: StructType, + tableSchema: Option[StructType], + isCaseSensitive: Boolean, + dialect: JdbcDialect, + options: JDBCOptions): String = { + val columns = getInsertColumns(rddSchema, tableSchema, isCaseSensitive, dialect) + dialect.getUpsertStatement(table, columns, isCaseSensitive, options) + } + /** * Retrieve standard jdbc types. * @@ -958,6 +973,7 @@ object JdbcUtils extends Logging with SQLConfHelper { df: DataFrame, tableSchema: Option[StructType], isCaseSensitive: Boolean, + upsert: Boolean, options: JdbcOptionsInWrite): Unit = { val url = options.url val table = options.table @@ -966,7 +982,12 @@ object JdbcUtils extends Logging with SQLConfHelper { val batchSize = options.batchSize val isolationLevel = options.isolationLevel - val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect) + val insertStmt = if (upsert) { + getUpsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect, options) + } else { + getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect) + } + val repartitionedDF = options.numPartitions match { case Some(n) if n <= 0 => throw QueryExecutionErrors.invalidJdbcNumPartitionsError( n, JDBCOptions.JDBC_NUM_PARTITIONS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala index 7449f66ee020f..1b0bf0f78815b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala @@ -42,7 +42,8 @@ case class JDBCWriteBuilder(schema: StructType, options: JdbcOptionsInWrite) ext val conn = dialect.createConnectionFactory(options)(-1) JdbcUtils.truncateTable(conn, options) } - JdbcUtils.saveTable(data, Some(schema), SQLConf.get.caseSensitiveAnalysis, options) + JdbcUtils.saveTable( + data, Some(schema), SQLConf.get.caseSensitiveAnalysis, upsert = false, options) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 
5f69d18cad756..32a006a343980 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -335,6 +335,17 @@ abstract class JdbcDialect extends Serializable with Logging { s"TRUNCATE TABLE $table" } + @Since("3.5.0") + def supportsUpsert(): Boolean = false + + @Since("3.5.0") + def getUpsertStatement( + tableName: String, + columns: Array[StructField], + isCaseSensitive: Boolean, + options: JDBCOptions): String = + throw new UnsupportedOperationException("upserts are not supported") + /** * Override connection specific properties to run before a select is made. This is in place to * allow dialects that need special treatment to optimize behavior. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 862e99adc3b0d..d1c73c2100ab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, SortDirection} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.MsSqlServerDialect.{GEOGRAPHY, GEOMETRY} import org.apache.spark.sql.types._ @@ -139,6 +139,54 @@ private case class MsSqlServerDialect() extends JdbcDialect { case _ => None } + override def supportsUpsert(): Boolean = true + + override def getUpsertStatement( + tableName: String, + columns: Array[StructField], + isCaseSensitive: Boolean, + options: JDBCOptions): String = { + val insertColumns = columns.map(_.name).map(quoteIdentifier) + val inputs = columns + .map(_.dataType) + .map(t => JdbcUtils.getJdbcType(t, this).databaseTypeDefinition) + .zipWithIndex.map { + case (t, idx) => s"DECLARE @param$idx $t; SET @param$idx = ?;" + }.mkString("\n") + val values = columns.indices.map(i => s"@param$i").mkString(", ") + val keyColumns = columns.zipWithIndex.filter { + case (col, _) => options.upsertKeyColumns.contains(col.name) + } + val updateColumns = columns.zipWithIndex.filterNot { + case (col, _) => options.upsertKeyColumns.contains(col.name) + } + val whereClause = keyColumns.map { + case (key, idx) => s"${quoteIdentifier(key.name)} = @param$idx" + }.mkString(" AND ") + val updateClause = updateColumns.map { + case (col, idx) => s"${quoteIdentifier(col.name)} = @param$idx" + }.mkString(", ") + + s""" + |$inputs + | + |INSERT $tableName (${insertColumns.mkString(", ")}) + |SELECT $values + |WHERE NOT EXISTS ( + | SELECT 1 + | FROM $tableName WITH (UPDLOCK, SERIALIZABLE) + | WHERE $whereClause + |) + | + |IF (@@ROWCOUNT = 0) + |BEGIN + | UPDATE TOP (1) $tableName + | SET $updateClause + | WHERE $whereClause + |END + |""".stripMargin + } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) // scalastyle:off line.size.limit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index d98fcdfd0b23f..f954eda39e804 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -170,6 +170,31 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper { schemaBuilder.result() } + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + + override def supportsUpsert(): Boolean = true + + override def getUpsertStatement( + tableName: String, + columns: Array[StructField], + isCaseSensitive: Boolean, + options: JDBCOptions): String = { + val insertColumns = columns.map(_.name).map(quoteIdentifier) + val placeholders = columns.map(_ => "?").mkString(",") + val upsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier) + val updateColumns = insertColumns.filterNot(upsertKeyColumns.contains) + val updateClause = + updateColumns.map(x => s"$x = VALUES($x)").mkString(", ") + + s""" + |INSERT INTO $tableName (${insertColumns.mkString(", ")}) + |VALUES ( $placeholders ) + |ON DUPLICATE KEY UPDATE $updateClause + |""".stripMargin + } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) // See https://dev.mysql.com/doc/refman/8.0/en/alter-table.html @@ -307,7 +332,7 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper { } } catch { case _: Exception => - logWarning("Cannot retrieved index info.") + logWarning("Cannot retrieve index info.") } indexMap.values.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index e4cd79e3f53c7..b2a359a0a2ac0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -158,6 +158,32 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper { case _ => None } + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + + override def supportsUpsert(): Boolean = true + + override def getUpsertStatement( + tableName: String, + columns: Array[StructField], + isCaseSensitive: Boolean, + options: JDBCOptions): String = { + val insertColumns = columns.map(_.name).map(quoteIdentifier) + val placeholders = columns.map(_ => "?").mkString(",") + val upsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier) + val updateColumns = insertColumns.filterNot(upsertKeyColumns.contains) + val updateClause = + updateColumns.map(x => s"$x = EXCLUDED.$x").mkString(", ") + + s""" + |INSERT INTO $tableName (${insertColumns.mkString(", ")}) + |VALUES ( $placeholders ) + |ON CONFLICT (${upsertKeyColumns.mkString(", ")}) + |DO UPDATE SET $updateClause + |""".stripMargin + } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 34c554f7d37e1..2bb705c75c987 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.{DataSourceScanExec, ExtendedMode, Project import org.apache.spark.sql.execution.command.{ExplainCommand, ShowCreateTableCommand} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation, JdbcUtils} 
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.JDBC_UPSERT_KEY_COLUMNS import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper import org.apache.spark.sql.functions.{lit, percentile_approx} import org.apache.spark.sql.internal.SQLConf @@ -1156,6 +1157,68 @@ class JDBCSuite extends QueryTest with SharedSparkSession { assert(db2.getTruncateQuery(table, Some(true)) == db2Query) } + Seq( + (JdbcDialects.get("jdbc:mysql://127.0.0.1/db"), + """ + |INSERT INTO table (`id`, `time`, `value`, `comment`) + |VALUES ( ?,?,?,? ) + |ON DUPLICATE KEY UPDATE `value` = VALUES(`value`), `comment` = VALUES(`comment`) + |""".stripMargin), + (JdbcDialects.get("jdbc:postgresql://127.0.0.1/db"), + """ + |INSERT INTO table ("id", "time", "value", "comment") + |VALUES ( ?,?,?,? ) + |ON CONFLICT ("id", "time") + |DO UPDATE SET "value" = EXCLUDED."value", "comment" = EXCLUDED."comment" + |""".stripMargin), + (JdbcDialects.get("jdbc:sqlserver://localhost/db"), + """ + |DECLARE @param0 BIGINT; SET @param0 = ?; + |DECLARE @param1 DATETIME; SET @param1 = ?; + |DECLARE @param2 DOUBLE PRECISION; SET @param2 = ?; + |DECLARE @param3 NVARCHAR(MAX); SET @param3 = ?; + | + |INSERT table ("id", "time", "value", "comment") + |SELECT @param0, @param1, @param2, @param3 + |WHERE NOT EXISTS ( + | SELECT 1 + | FROM table WITH (UPDLOCK, SERIALIZABLE) + | WHERE "id" = @param0 AND "time" = @param1 + |) + | + |IF (@@ROWCOUNT = 0) + |BEGIN + | UPDATE TOP (1) table + | SET "value" = @param2, "comment" = @param3 + | WHERE "id" = @param0 AND "time" = @param1 + |END + |""".stripMargin) + ).foreach { case (dialect, expected) => + test(s"upsert table query by dialect - ${dialect.getClass.getSimpleName.stripSuffix("$")}") { + assert(dialect.supportsUpsert() === true) + + val options = { + new JDBCOptions(Map( + JDBC_UPSERT_KEY_COLUMNS -> "id, time", + JDBCOptions.JDBC_URL -> url, + JDBCOptions.JDBC_TABLE_NAME -> "table" + )) + } + + val table = "table" + val columns = Array( + StructField("id", LongType), + StructField("time", TimestampType), + StructField("value", DoubleType), + StructField("comment", StringType) + ) + val isCaseSensitive = false + val stmt = dialect.getUpsertStatement(table, columns, isCaseSensitive, options) + + assert(stmt === expected) + } + } + test("Test DataFrame.where for Date and Timestamp") { // Regression test for bug SPARK-11788 val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543");
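The supportsUpsert/getUpsertStatement hooks added to JdbcDialect are what the MySQL, PostgreSQL, and SQL Server dialects override above. As a hedged sketch (not part of this change), a third-party dialect for a database that accepts PostgreSQL/SQLite-style INSERT ... ON CONFLICT could opt in the same way; the dialect name, URL prefix, and SQL flavour below are illustrative assumptions.

```scala
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types.StructField

// Hypothetical dialect for a database that accepts PostgreSQL/SQLite-style
// INSERT ... ON CONFLICT; it only demonstrates the new extension points.
object ExampleUpsertDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:example")

  // Advertise upsert support so JdbcRelationProvider does not reject the option.
  override def supportsUpsert(): Boolean = true

  // Build one parameterized statement per row, keyed on options.upsertKeyColumns.
  override def getUpsertStatement(
      tableName: String,
      columns: Array[StructField],
      isCaseSensitive: Boolean,
      options: JDBCOptions): String = {
    val insertColumns = columns.map(c => quoteIdentifier(c.name))
    val placeholders = columns.map(_ => "?").mkString(", ")
    val keyColumns = options.upsertKeyColumns.map(quoteIdentifier)
    val updateClause = insertColumns
      .filterNot(keyColumns.contains)
      .map(c => s"$c = EXCLUDED.$c")
      .mkString(", ")

    s"""
       |INSERT INTO $tableName (${insertColumns.mkString(", ")})
       |VALUES ( $placeholders )
       |ON CONFLICT (${keyColumns.mkString(", ")})
       |DO UPDATE SET $updateClause
       |""".stripMargin
  }
}
```

Registering such a dialect with JdbcDialects.registerDialect(ExampleUpsertDialect) would route df.write.option("upsert", "true") through the same saveTable/getUpsertStatement path as the built-in MySQL, PostgreSQL, and SQL Server dialects.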