From f2fdaacde3f7dfbd1561e48cf9f985b1019ef584 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 1 Jun 2023 21:12:33 +0200 Subject: [PATCH 01/12] Implement upsert for MySQL --- .../sql/jdbc/MySQLIntegrationSuite.scala | 51 ++++++++++++++++++- docs/sql-data-sources-jdbc.md | 13 +++++ .../sql/errors/QueryCompilationErrors.scala | 8 +++ .../datasources/jdbc/JDBCOptions.scala | 8 ++- .../jdbc/JdbcRelationProvider.scala | 12 +++-- .../datasources/jdbc/JdbcUtils.scala | 45 +++++++++++----- .../v2/jdbc/JDBCWriteBuilder.scala | 3 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 11 ++++ .../apache/spark/sql/jdbc/MySQLDialect.scala | 27 +++++++++- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 21 ++++++++ 10 files changed, 178 insertions(+), 21 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 684cec37c1703..c687893fa47f8 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -25,6 +25,7 @@ import java.util.Properties import scala.util.Using import org.apache.spark.sql.Row +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.ShortType @@ -40,6 +41,8 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { + import testImplicits._ + override val db = new MySQLDatabaseOnDocker override def dataPreparation(conn: Connection): Unit = { @@ -97,6 +100,14 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { conn.prepareStatement("CREATE TABLE TBL_GEOMETRY (col0 GEOMETRY)").executeUpdate() conn.prepareStatement("INSERT INTO TBL_GEOMETRY VALUES (ST_GeomFromText('POINT(0 0)'))") .executeUpdate() + + conn.prepareStatement("CREATE TABLE upsert (id INTEGER, ts TIMESTAMP, v1 DOUBLE, v2 DOUBLE, " + + "PRIMARY KEY pk (id, ts))").executeUpdate() + conn.prepareStatement("INSERT INTO upsert VALUES " + + "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " + + "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " + + "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " + + "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate() } def testConnection(): Unit = { @@ -254,7 +265,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) - df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df1.write.mode(SaveMode.Append).jdbc(jdbcUrl, "numberscopy", new Properties) df2.write.jdbc(jdbcUrl, "datescopy", new Properties) df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) } @@ -285,7 +296,6 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { assert(sql("select x, y from queryOption").collect().toSet == expectedResult) } - test("SPARK-47478: all boolean synonyms read-write roundtrip") { val df = sqlContext.read.jdbc(jdbcUrl, "bools", new Properties) checkAnswer(df, Row(true, true, true)) @@ -358,6 +368,43 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { val df = spark.read.jdbc(jdbcUrl, "smallint_round_trip", new 
Properties) assert(df.schema.fields.head.dataType === ShortType) } + + Seq(false, true).foreach { exists => + test(s"Upsert ${if (exists) "" else "non-"}existing table") { + val df = Seq( + (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), // row unchanged + (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), // updates v1 + (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), // updates v1 and v2 + (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) // inserts new row + ).toDF("id", "ts", "v1", "v2").repartition(10) + + val table = if (exists) "upsert" else "new_table" + val options = Map("numPartitions" -> "10", "upsert" -> "true") + df.write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties) + + val actual = spark.read.jdbc(jdbcUrl, table, new Properties).collect.toSet + val existing = if (exists) { + Set((1, Timestamp.valueOf("1996-01-01 01:23:45"), 1.234, 1.234567)) + } else { + Set.empty + } + val upsertedRows = Set( + (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), + (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), + (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), + (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) + ) + val expected = (existing ++ upsertedRows).map { case (id, ts, v1, v2) => + Row(Integer.valueOf(id), ts, v1.doubleValue(), v2.doubleValue()) + } + assert(actual === expected) + } + } + + test("Write with unspecified mode with upsert") { } + test("Write with overwrite mode with upsert") { } + test("Write with error-if-exists mode with upsert") { } + test("Write with ignore mode with upsert") { } } diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md index 637efc24113ef..0a0d320ce9965 100644 --- a/docs/sql-data-sources-jdbc.md +++ b/docs/sql-data-sources-jdbc.md @@ -261,6 +261,19 @@ logging into the data sources. write + + upsert, upsertKeyColumns + + These are JDBC writer related options. They describe how to + use the UPSERT feature of the different JDBC dialects. The upsert option applies only when SaveMode.Append is used. + Set upsert to true to enable upsert append mode. The database is queried for the primary index to detect + the upsert key columns that are used to identify the rows to update. The upsert key columns can also be + defined explicitly via the upsertKeyColumns option as a comma-separated list of column names. + Be aware that if the input data set contains duplicate rows, the upsert operation is + non-deterministic; see [upsert (merge)](https://en.wikipedia.org/wiki/Merge_(SQL)) for details.
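For illustration, a minimal usage sketch of these two options, mirroring the integration tests added in this patch. The JDBC URL, the table name "events", and the key columns "id, ts" are placeholders, and `df` is assumed to be an existing DataFrame whose schema matches the target table:

```scala
import java.util.Properties
import org.apache.spark.sql.SaveMode

// Append `df` into the (hypothetical) "events" table; rows whose ("id", "ts") key
// already exists are updated instead of causing duplicate-key failures.
df.write
  .mode(SaveMode.Append)                 // upsert only takes effect in Append mode
  .option("upsert", "true")              // rewrite INSERTs as dialect-specific upserts
  .option("upsertKeyColumns", "id, ts")  // key columns identifying the row to update
  .jdbc("jdbc:postgresql://host:5432/db", "events", new Properties())
```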
+ + + customSchema (none) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 24c85b76437cc..c11a577270476 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1495,6 +1495,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat tableDoesNotSupportError("truncates", table) } + def tableDoesNotSupportUpsertsError(table: String): Throwable = { + new AnalysisException( + errorClass = "_LEGACY_ERROR_TEMP_1121", + messageParameters = Map( + "cmd" -> "upserts", + "table" -> table)) + } + def tableDoesNotSupportPartitionManagementError(table: Table): Throwable = { tableDoesNotSupportError("partition management", table) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 43db0c6eef114..7379feaa0e24c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -160,8 +160,12 @@ class JDBCOptions( // ------------------------------------------------------------ // if to truncate the table from the JDBC database val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean - val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean) + // if to upsert the table in the JDBC database + val isUpsert = parameters.getOrElse(JDBC_UPSERT, "false").toBoolean + // the columns used to identify update and insert rows in upsert mode + val upsertKeyColumns = parameters.getOrElse(JDBC_UPSERT_KEY_COLUMNS, "").split(",").map(_.trim) + // the create table option , which can be table_options or partition_options. 
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options @@ -296,6 +300,8 @@ object JDBCOptions { val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate") + val JDBC_UPSERT = newOption("upsert") + val JDBC_UPSERT_KEY_COLUMNS = newOption("upsertKeyColumns") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index d9be1a1e3f674..d69d2cc58e077 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -50,6 +50,7 @@ class JdbcRelationProvider extends CreatableRelationProvider val dialect = JdbcDialects.get(options.url) val conn = dialect.createConnectionFactory(options)(-1) try { + val upsert = mode == SaveMode.Append && options.isUpsert val tableExists = JdbcUtils.tableExists(conn, options) if (tableExists) { mode match { @@ -58,17 +59,20 @@ class JdbcRelationProvider extends CreatableRelationProvider // In this case, we should truncate table and then load. truncateTable(conn, options) val tableSchema = JdbcUtils.getSchemaOption(conn, options) - saveTable(df, tableSchema, isCaseSensitive, options) + saveTable(df, tableSchema, isCaseSensitive, upsert, options) } else { // Otherwise, do not truncate the table, instead drop and recreate it dropTable(conn, options.table, options) createTable(conn, options.table, df.schema, isCaseSensitive, options) - saveTable(df, Some(df.schema), isCaseSensitive, options) + saveTable(df, Some(df.schema), isCaseSensitive, upsert, options) } case SaveMode.Append => + if (options.isUpsert && !dialect.supportsUpsert) { + throw QueryCompilationErrors.tableDoesNotSupportUpsertsError(options.table) + } val tableSchema = JdbcUtils.getSchemaOption(conn, options) - saveTable(df, tableSchema, isCaseSensitive, options) + saveTable(df, tableSchema, isCaseSensitive, upsert, options) case SaveMode.ErrorIfExists => throw QueryCompilationErrors.tableOrViewAlreadyExistsError(options.table) @@ -80,7 +84,7 @@ class JdbcRelationProvider extends CreatableRelationProvider } } else { createTable(conn, options.table, df.schema, isCaseSensitive, options) - saveTable(df, Some(df.schema), isCaseSensitive, options) + saveTable(df, Some(df.schema), isCaseSensitive, upsert, options) } } finally { conn.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index c541ec16fc822..ab278c17c977b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -109,31 +109,46 @@ object JdbcUtils extends Logging with SQLConfHelper { JdbcDialects.get(url).isCascadingTruncateTable() } - /** - * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn. 
- */ - def getInsertStatement( - table: String, + protected def getInsertColumns( rddSchema: StructType, tableSchema: Option[StructType], isCaseSensitive: Boolean, - dialect: JdbcDialect): String = { - val columns = if (tableSchema.isEmpty) { + dialect: JdbcDialect): Array[StructField] = { + if (tableSchema.isEmpty) { rddSchema.fields } else { - // The generated insert statement needs to follow rddSchema's column sequence and - // tableSchema's column names. When appending data into some case-sensitive DBMSs like - // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of - // RDD column names for user convenience. rddSchema.fields.map { col => tableSchema.get.find(f => conf.resolver(f.name, col.name)).getOrElse { throw QueryCompilationErrors.columnNotFoundInSchemaError(col, tableSchema) } } } + } + + /** + * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn. + */ + def getInsertStatement( + table: String, + rddSchema: StructType, + tableSchema: Option[StructType], + isCaseSensitive: Boolean, + dialect: JdbcDialect): String = { + val columns = getInsertColumns(rddSchema, tableSchema, isCaseSensitive, dialect) dialect.insertIntoTable(table, columns) } + def getUpsertStatement( + table: String, + rddSchema: StructType, + tableSchema: Option[StructType], + isCaseSensitive: Boolean, + dialect: JdbcDialect, + options: JDBCOptions): String = { + val columns = getInsertColumns(rddSchema, tableSchema, isCaseSensitive, dialect) + dialect.getUpsertStatement(table, columns, isCaseSensitive, options) + } + /** * Retrieve standard jdbc types. * @@ -958,6 +973,7 @@ object JdbcUtils extends Logging with SQLConfHelper { df: DataFrame, tableSchema: Option[StructType], isCaseSensitive: Boolean, + upsert: Boolean, options: JdbcOptionsInWrite): Unit = { val url = options.url val table = options.table @@ -966,7 +982,12 @@ object JdbcUtils extends Logging with SQLConfHelper { val batchSize = options.batchSize val isolationLevel = options.isolationLevel - val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect) + val insertStmt = if (upsert) { + getUpsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect, options) + } else { + getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect) + } + val repartitionedDF = options.numPartitions match { case Some(n) if n <= 0 => throw QueryExecutionErrors.invalidJdbcNumPartitionsError( n, JDBCOptions.JDBC_NUM_PARTITIONS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala index 7449f66ee020f..1b0bf0f78815b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala @@ -42,7 +42,8 @@ case class JDBCWriteBuilder(schema: StructType, options: JdbcOptionsInWrite) ext val conn = dialect.createConnectionFactory(options)(-1) JdbcUtils.truncateTable(conn, options) } - JdbcUtils.saveTable(data, Some(schema), SQLConf.get.caseSensitiveAnalysis, options) + JdbcUtils.saveTable( + data, Some(schema), SQLConf.get.caseSensitiveAnalysis, upsert = false, options) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 
5f69d18cad756..32a006a343980 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -335,6 +335,17 @@ abstract class JdbcDialect extends Serializable with Logging { s"TRUNCATE TABLE $table" } + @Since("3.5.0") + def supportsUpsert(): Boolean = false + + @Since("3.5.0") + def getUpsertStatement( + tableName: String, + columns: Array[StructField], + isCaseSensitive: Boolean, + options: JDBCOptions): String = + throw new UnsupportedOperationException("upserts are not supported") + /** * Override connection specific properties to run before a select is made. This is in place to * allow dialects that need special treatment to optimize behavior. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index d98fcdfd0b23f..aaebe39adea5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -170,6 +170,31 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper { schemaBuilder.result() } + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + + override def supportsUpsert(): Boolean = true + + override def getUpsertStatement( + tableName: String, + columns: Array[String], + isCaseSensitive: Boolean, + options: JDBCOptions): String = { + val insertColumns = columns.mkString(", ") + val placeholders = columns.map(_ => "?").mkString(",") + val upsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier) + val updateColumns = columns.filterNot(upsertKeyColumns.contains) + val updateClause = + updateColumns.map(x => s"$x = VALUES($x)").mkString(", ") + + s""" + |INSERT INTO $tableName ($insertColumns) + |VALUES ( $placeholders ) + |ON DUPLICATE KEY UPDATE $updateClause + |""".stripMargin + } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) // See https://dev.mysql.com/doc/refman/8.0/en/alter-table.html @@ -307,7 +332,7 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper { } } catch { case _: Exception => - logWarning("Cannot retrieved index info.") + logWarning("Cannot retrieve index info.") } indexMap.values.toArray } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 34c554f7d37e1..b7c8a93f949cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.{DataSourceScanExec, ExtendedMode, Project import org.apache.spark.sql.execution.command.{ExplainCommand, ShowCreateTableCommand} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.JDBC_UPSERT_KEY_COLUMNS import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper import org.apache.spark.sql.functions.{lit, percentile_approx} import org.apache.spark.sql.internal.SQLConf @@ -1156,6 +1157,26 @@ class JDBCSuite extends QueryTest with SharedSparkSession { assert(db2.getTruncateQuery(table, Some(true)) == db2Query) } + test("upsert table query by dialect") { + val options = { + new 
JDBCOptions(Map( + JDBC_UPSERT_KEY_COLUMNS -> "id, time", + JDBCOptions.JDBC_URL -> url, + JDBCOptions.JDBC_TABLE_NAME -> "table" + )) + } + val columns = Array("id", "time", "value", "comment") + + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val quotedColumns = columns.map(MySQL.quoteIdentifier) + val isCaseSensitive = false + val mysqlStmt = MySQL.getUpsertStatement("table", quotedColumns, isCaseSensitive, options) + assert(mysqlStmt === "\n" + + "INSERT INTO table (`id`, `time`, `value`, `comment`)\n" + + "VALUES ( ?,?,?,? )\n" + + "ON DUPLICATE KEY UPDATE `value` = VALUES(`value`), `comment` = VALUES(`comment`)\n") + } + test("Test DataFrame.where for Date and Timestamp") { // Regression test for bug SPARK-11788 val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); From 781c713560465b722cf3de33453064e78a0c319b Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Fri, 2 Jun 2023 15:30:05 +0200 Subject: [PATCH 02/12] Implement upsert for Postgres --- .../sql/jdbc/PostgresIntegrationSuite.scala | 46 +++++++++++++++- .../spark/sql/jdbc/PostgresDialect.scala | 26 ++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 52 +++++++++++++------ 3 files changed, 106 insertions(+), 18 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 8c0a7c0a809f0..bf706efe47d0a 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -24,7 +24,7 @@ import java.time.LocalDateTime import java.util.Properties import org.apache.spark.SparkException -import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.{Column, DataFrame, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -39,6 +39,8 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { + import testImplicits._ + override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") override val env = Map( @@ -190,6 +192,14 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { conn.prepareStatement("CREATE DOMAIN myint AS integer CHECK (VALUE > 0)").executeUpdate() conn.prepareStatement("CREATE TABLE domain_table (c1 myint)").executeUpdate() conn.prepareStatement("INSERT INTO domain_table VALUES (1)").executeUpdate() + + conn.prepareStatement("CREATE TABLE upsert (id integer, ts timestamp, v1 double precision, " + + "v2 double precision, CONSTRAINT pk PRIMARY KEY (id, ts))").executeUpdate() + conn.prepareStatement("INSERT INTO upsert VALUES " + + "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " + + "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " + + "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " + + "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate() } test("Type mapping for various types") { @@ -321,6 +331,40 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { .collect() rows.head.toSeq.tail.foreach(c => assert(c.isInstanceOf[java.sql.Timestamp])) } + + test(s"Upsert existing table") { + val df = Seq( + (1, Timestamp.valueOf("1996-01-01 01:23:46"), 
1.235, 1.234568), // row unchanged + (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), // updates v1 + (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), // updates v1 and v2 + (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) // inserts new row + ).toDF("id", "ts", "v1", "v2").repartition(10) + + val options = Map("numPartitions" -> "10", "upsert" -> "true", "upsertKeyColumns" -> "id, ts") + df.write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, "upsert", new Properties) + + val actual = spark.read.jdbc(jdbcUrl, "upsert", new Properties).collect.toSet + val expected = Set( + (1, Timestamp.valueOf("1996-01-01 01:23:45"), 1.234, 1.234567), + (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), + (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), + (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), + (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) + ).map { case (id, ts, v1, v2) => + Row(Integer.valueOf(id), ts, v1.doubleValue(), v2.doubleValue()) + } + assert(actual === expected) + } + + test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE " + + "should be recognized") { + // When using JDBC to read the columns of TIMESTAMP with TIME ZONE and TIME with TIME ZONE + // the actual types are java.sql.Types.TIMESTAMP and java.sql.Types.TIME + val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) + val rows = dfRead.collect() + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) } test("SPARK-22291: Conversion error when transforming array types of " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index e4cd79e3f53c7..65797f161ed54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -158,6 +158,32 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper { case _ => None } + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + + override def supportsUpsert(): Boolean = true + + override def getUpsertStatement( + tableName: String, + columns: Array[String], + isCaseSensitive: Boolean, + options: JDBCOptions): String = { + val insertColumns = columns.mkString(", ") + val placeholders = columns.map(_ => "?").mkString(",") + val upsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier) + val updateColumns = columns.filterNot(upsertKeyColumns.contains) + val updateClause = + updateColumns.map(x => s"$x = EXCLUDED.$x").mkString(", ") + + s""" + |INSERT INTO $tableName ($insertColumns) + |VALUES ( $placeholders ) + |ON CONFLICT (${upsertKeyColumns.mkString(", ")}) + |DO UPDATE SET $updateClause + |""".stripMargin + } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index b7c8a93f949cb..7f7a7649335af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1157,24 +1157,42 @@ class JDBCSuite extends QueryTest with SharedSparkSession { assert(db2.getTruncateQuery(table, Some(true)) 
== db2Query) } - test("upsert table query by dialect") { - val options = { - new JDBCOptions(Map( - JDBC_UPSERT_KEY_COLUMNS -> "id, time", - JDBCOptions.JDBC_URL -> url, - JDBCOptions.JDBC_TABLE_NAME -> "table" - )) - } - val columns = Array("id", "time", "value", "comment") + Seq( + ("MySQL", + JdbcDialects.get("jdbc:mysql://127.0.0.1/db"), + """ + |INSERT INTO table (`id`, `time`, `value`, `comment`) + |VALUES ( ?,?,?,? ) + |ON DUPLICATE KEY UPDATE `value` = VALUES(`value`), `comment` = VALUES(`comment`) + |""".stripMargin), + ("Postgres", + JdbcDialects.get("jdbc:postgresql://127.0.0.1/db"), + """ + |INSERT INTO table ("id", "time", "value", "comment") + |VALUES ( ?,?,?,? ) + |ON CONFLICT ("id", "time") + |DO UPDATE SET "value" = EXCLUDED."value", "comment" = EXCLUDED."comment" + |""".stripMargin) + ).foreach { case (label, dialect, expected) => + test(s"upsert table query by dialect - ${dialect.getClass.getSimpleName.stripSuffix("$")}") { + assert(dialect.supportsUpsert() === true) + + val options = { + new JDBCOptions(Map( + JDBC_UPSERT_KEY_COLUMNS -> "id, time", + JDBCOptions.JDBC_URL -> url, + JDBCOptions.JDBC_TABLE_NAME -> "table" + )) + } - val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val quotedColumns = columns.map(MySQL.quoteIdentifier) - val isCaseSensitive = false - val mysqlStmt = MySQL.getUpsertStatement("table", quotedColumns, isCaseSensitive, options) - assert(mysqlStmt === "\n" + - "INSERT INTO table (`id`, `time`, `value`, `comment`)\n" + - "VALUES ( ?,?,?,? )\n" + - "ON DUPLICATE KEY UPDATE `value` = VALUES(`value`), `comment` = VALUES(`comment`)\n") + val table = "table" + val columns = Array("id", "time", "value", "comment") + val quotedColumns = columns.map(dialect.quoteIdentifier) + val isCaseSensitive = false + val stmt = dialect.getUpsertStatement(table, quotedColumns, isCaseSensitive, options) + + assert(stmt === expected) + } } test("Test DataFrame.where for Date and Timestamp") { From adedde7c254b3004b164cf36bf48ba1533e6c8b5 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Mon, 5 Jun 2023 14:38:06 +0200 Subject: [PATCH 03/12] Code cleanup --- .../scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index c687893fa47f8..15b8ae2263292 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -265,7 +265,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) - df1.write.mode(SaveMode.Append).jdbc(jdbcUrl, "numberscopy", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) df2.write.jdbc(jdbcUrl, "datescopy", new Properties) df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) } From 6aa37b4ce639a62ab86cf853b0501c36e26670f1 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 8 Jun 2023 15:42:31 +0200 Subject: [PATCH 04/12] Move upsert tests into trait --- .../jdbc/MsSqlServerIntegrationSuite.scala | 11 ++- .../sql/jdbc/MySQLIntegrationSuite.scala | 41 +---------- 
.../sql/jdbc/PostgresIntegrationSuite.scala | 30 +------- .../apache/spark/sql/jdbc/UpsertTests.scala | 68 +++++++++++++++++++ 4 files changed, 81 insertions(+), 69 deletions(-) create mode 100644 connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index 8bceb9506e850..c9bc5e3a538f7 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { +class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests { override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", "mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") @@ -150,6 +150,14 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { """ |INSERT INTO bits VALUES (1, 2, 1) """.stripMargin).executeUpdate() + + conn.prepareStatement("CREATE TABLE upsert (id INT, ts DATETIME, v1 FLOAT, v2 FLOAT, " + + "CONSTRAINT pk_upsert PRIMARY KEY (id, ts))").executeUpdate() + conn.prepareStatement("INSERT INTO upsert VALUES " + + "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " + + "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " + + "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " + + "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate() } test("Basic test") { @@ -437,4 +445,5 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { .load() assert(df.collect().toSet === expectedResult) } + } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 15b8ae2263292..7ab9eadddbee8 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -25,7 +25,6 @@ import java.util.Properties import scala.util.Using import org.apache.spark.sql.Row -import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.ShortType @@ -40,9 +39,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { - import testImplicits._ - +class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests { override val db = new MySQLDatabaseOnDocker override def dataPreparation(conn: Connection): Unit = { @@ -369,42 +366,6 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { assert(df.schema.fields.head.dataType === ShortType) } - Seq(false, true).foreach { exists => - test(s"Upsert ${if (exists) "" else "non-"}existing table") { - val df = Seq( - (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), // row unchanged - (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), // updates v1 - (2, 
Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), // updates v1 and v2 - (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) // inserts new row - ).toDF("id", "ts", "v1", "v2").repartition(10) - - val table = if (exists) "upsert" else "new_table" - val options = Map("numPartitions" -> "10", "upsert" -> "true") - df.write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties) - - val actual = spark.read.jdbc(jdbcUrl, table, new Properties).collect.toSet - val existing = if (exists) { - Set((1, Timestamp.valueOf("1996-01-01 01:23:45"), 1.234, 1.234567)) - } else { - Set.empty - } - val upsertedRows = Set( - (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), - (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), - (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), - (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) - ) - val expected = (existing ++ upsertedRows).map { case (id, ts, v1, v2) => - Row(Integer.valueOf(id), ts, v1.doubleValue(), v2.doubleValue()) - } - assert(actual === expected) - } - } - - test("Write with unspecified mode with upsert") { } - test("Write with overwrite mode with upsert") { } - test("Write with error-if-exists mode with upsert") { } - test("Write with ignore mode with upsert") { } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index bf706efe47d0a..2ace2da17889c 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -24,7 +24,7 @@ import java.time.LocalDateTime import java.util.Properties import org.apache.spark.SparkException -import org.apache.spark.sql.{Column, DataFrame, Row, SaveMode} +import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -38,9 +38,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { - import testImplicits._ - +class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests { override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") override val env = Map( @@ -332,30 +330,6 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { rows.head.toSeq.tail.foreach(c => assert(c.isInstanceOf[java.sql.Timestamp])) } - test(s"Upsert existing table") { - val df = Seq( - (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), // row unchanged - (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), // updates v1 - (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), // updates v1 and v2 - (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) // inserts new row - ).toDF("id", "ts", "v1", "v2").repartition(10) - - val options = Map("numPartitions" -> "10", "upsert" -> "true", "upsertKeyColumns" -> "id, ts") - df.write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, "upsert", new Properties) - - val actual = spark.read.jdbc(jdbcUrl, "upsert", new Properties).collect.toSet - val expected = Set( - (1, Timestamp.valueOf("1996-01-01 01:23:45"), 
1.234, 1.234567), - (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), - (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), - (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), - (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) - ).map { case (id, ts, v1, v2) => - Row(Integer.valueOf(id), ts, v1.doubleValue(), v2.doubleValue()) - } - assert(actual === expected) - } - test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE " + "should be recognized") { // When using JDBC to read the columns of TIMESTAMP with TIME ZONE and TIME with TIME ZONE diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala new file mode 100644 index 0000000000000..fd1d400665b89 --- /dev/null +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc + +import java.sql.Timestamp +import java.util.Properties + +import org.apache.spark.sql.{Row, SaveMode} + +trait UpsertTests { + self: DockerJDBCIntegrationSuite => + + import testImplicits._ + + test(s"Upsert existing table") { doTestUpsert(true) } + test(s"Upsert non-existing table") { doTestUpsert(false) } + + def doTestUpsert(tableExists: Boolean): Unit = { + val df = Seq( + (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), // row unchanged + (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), // updates v1 + (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), // updates v1 and v2 + (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) // inserts new row + ).toDF("id", "ts", "v1", "v2").repartition(1) // .repartition(10) + + val table = if (tableExists) "upsert" else "new_table" + val options = Map("numPartitions" -> "10", "upsert" -> "true", "upsertKeyColumns" -> "id, ts") + df.write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties) + + val actual = spark.read.jdbc(jdbcUrl, table, new Properties).collect.toSet + val existing = if (tableExists) { + Set((1, Timestamp.valueOf("1996-01-01 01:23:45"), 1.234, 1.234567)) + } else { + Set.empty + } + val upsertedRows = Set( + (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), + (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), + (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), + (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) + ) + val expected = (existing ++ upsertedRows).map { case (id, ts, v1, v2) => + Row(Integer.valueOf(id), ts, v1.doubleValue(), v2.doubleValue()) + } + assert(actual === expected) + } + + test("Upsert null values") {} + test("Write with unspecified mode with upsert") {} + test("Write with overwrite mode with upsert") {} + test("Write with error-if-exists mode with upsert") {} + test("Write with ignore mode with upsert") {} +} From 5ce43daea797c483d61ce2655e38ff9ceace2fd6 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 8 Jun 2023 15:39:20 +0200 Subject: [PATCH 05/12] Implement upsert for MsSqlServer --- .../spark/sql/jdbc/MsSqlServerDialect.scala | 51 ++++++++++++++++- .../apache/spark/sql/jdbc/MySQLDialect.scala | 4 +- .../spark/sql/jdbc/PostgresDialect.scala | 1 + .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 55 +++++++++++++------ 4 files changed, 91 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 862e99adc3b0d..1b10dbd2893f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, SortDirection} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.MsSqlServerDialect.{GEOGRAPHY, GEOMETRY} import org.apache.spark.sql.types._ @@ -139,6 +139,55 @@ private case class MsSqlServerDialect() extends JdbcDialect { case _ => None } + override def 
supportsUpsert(): Boolean = true + + override def getUpsertStatement( + tableName: String, + columns: Array[String], + types: Array[DataType], + isCaseSensitive: Boolean, + options: JDBCOptions): String = { + val insertColumns = columns.mkString(", ") + val inputs = types + .map(t => JdbcUtils.getJdbcType(t, this).databaseTypeDefinition) + .zipWithIndex.map { + case (t, idx) => s"DECLARE @param$idx $t; SET @param$idx = ?;" + }.mkString("\n") + val values = columns.indices.map(i => s"@param$i").mkString(", ") + val quotedUpsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier) + val keyColumns = columns.zipWithIndex.filter { + case (col, _) => quotedUpsertKeyColumns.contains(col) + } + val updateColumns = columns.zipWithIndex.filterNot { + case (col, _) => quotedUpsertKeyColumns.contains(col) + } + val whereClause = keyColumns.map { + case (key, idx) => s"$key = @param$idx" + }.mkString(" AND ") + val updateClause = updateColumns.map { + case (col, idx) => s"$col = @param$idx" + }.mkString(", ") + + s""" + |$inputs + | + |INSERT $tableName ($insertColumns) + |SELECT $values + |WHERE NOT EXISTS ( + | SELECT 1 + | FROM $tableName WITH (UPDLOCK, SERIALIZABLE) + | WHERE $whereClause + |) + | + |IF (@@ROWCOUNT = 0) + |BEGIN + | UPDATE TOP (1) $tableName + | SET $updateClause + | WHERE $whereClause + |END + |""".stripMargin + } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) // scalastyle:off line.size.limit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index aaebe39adea5e..b5e309a7dee96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -178,13 +178,13 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper { override def getUpsertStatement( tableName: String, - columns: Array[String], + columns: Array[StructField], isCaseSensitive: Boolean, options: JDBCOptions): String = { val insertColumns = columns.mkString(", ") val placeholders = columns.map(_ => "?").mkString(",") val upsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier) - val updateColumns = columns.filterNot(upsertKeyColumns.contains) + val updateColumns = columns.filterNot(c => upsertKeyColumns.contains(c.name)) val updateClause = updateColumns.map(x => s"$x = VALUES($x)").mkString(", ") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 65797f161ed54..f7f3b9272824b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -167,6 +167,7 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper { override def getUpsertStatement( tableName: String, columns: Array[String], + types: Array[DataType], isCaseSensitive: Boolean, options: JDBCOptions): String = { val insertColumns = columns.mkString(", ") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 7f7a7649335af..b995a8c67310d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1158,22 +1158,42 @@ class JDBCSuite extends QueryTest with SharedSparkSession { } Seq( - 
("MySQL", - JdbcDialects.get("jdbc:mysql://127.0.0.1/db"), - """ - |INSERT INTO table (`id`, `time`, `value`, `comment`) - |VALUES ( ?,?,?,? ) - |ON DUPLICATE KEY UPDATE `value` = VALUES(`value`), `comment` = VALUES(`comment`) - |""".stripMargin), - ("Postgres", - JdbcDialects.get("jdbc:postgresql://127.0.0.1/db"), - """ - |INSERT INTO table ("id", "time", "value", "comment") - |VALUES ( ?,?,?,? ) - |ON CONFLICT ("id", "time") - |DO UPDATE SET "value" = EXCLUDED."value", "comment" = EXCLUDED."comment" - |""".stripMargin) - ).foreach { case (label, dialect, expected) => + (JdbcDialects.get("jdbc:mysql://127.0.0.1/db"), + """ + |INSERT INTO table (`id`, `time`, `value`, `comment`) + |VALUES ( ?,?,?,? ) + |ON DUPLICATE KEY UPDATE `value` = VALUES(`value`), `comment` = VALUES(`comment`) + |""".stripMargin), + (JdbcDialects.get("jdbc:postgresql://127.0.0.1/db"), + """ + |INSERT INTO table ("id", "time", "value", "comment") + |VALUES ( ?,?,?,? ) + |ON CONFLICT ("id", "time") + |DO UPDATE SET "value" = EXCLUDED."value", "comment" = EXCLUDED."comment" + |""".stripMargin), + (JdbcDialects.get("jdbc:sqlserver://localhost/db"), + """ + |DECLARE @param0 BIGINT; SET @param0 = ?; + |DECLARE @param1 DATETIME; SET @param1 = ?; + |DECLARE @param2 DOUBLE PRECISION; SET @param2 = ?; + |DECLARE @param3 NVARCHAR(MAX); SET @param3 = ?; + | + |INSERT table ("id", "time", "value", "comment") + |SELECT @param0, @param1, @param2, @param3 + |WHERE NOT EXISTS ( + | SELECT 1 + | FROM table WITH (UPDLOCK, SERIALIZABLE) + | WHERE "id" = @param0 AND "time" = @param1 + |) + | + |IF (@@ROWCOUNT = 0) + |BEGIN + | UPDATE TOP (1) table + | SET "value" = @param2, "comment" = @param3 + | WHERE "id" = @param0 AND "time" = @param1 + |END + |""".stripMargin) + ).foreach { case (dialect, expected) => test(s"upsert table query by dialect - ${dialect.getClass.getSimpleName.stripSuffix("$")}") { assert(dialect.supportsUpsert() === true) @@ -1188,8 +1208,9 @@ class JDBCSuite extends QueryTest with SharedSparkSession { val table = "table" val columns = Array("id", "time", "value", "comment") val quotedColumns = columns.map(dialect.quoteIdentifier) + val types: Array[DataType] = Array(LongType, TimestampType, DoubleType, StringType) val isCaseSensitive = false - val stmt = dialect.getUpsertStatement(table, quotedColumns, isCaseSensitive, options) + val stmt = dialect.getUpsertStatement(table, quotedColumns, types, isCaseSensitive, options) assert(stmt === expected) } From 9d9a8fe6c919ba89fbc4cdae8bdfff2312743a8b Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 8 Jun 2023 15:54:35 +0200 Subject: [PATCH 06/12] Fix non-existing upsert test for Postgres --- .../spark/sql/jdbc/PostgresIntegrationSuite.scala | 4 ++++ .../scala/org/apache/spark/sql/jdbc/UpsertTests.scala | 10 ++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 2ace2da17889c..c7d53eb75e38c 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -200,6 +200,10 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTes "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate() } + override val upsertTestOptions = Map( + 
"createTableOptions" -> "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)" + ) + test("Type mapping for various types") { val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) val rows = df.collect().sortBy(_.toString()) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala index fd1d400665b89..a06b3baf998b6 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala @@ -27,6 +27,8 @@ trait UpsertTests { import testImplicits._ + def upsertTestOptions: Map[String, String] = Map.empty + test(s"Upsert existing table") { doTestUpsert(true) } test(s"Upsert non-existing table") { doTestUpsert(false) } @@ -38,8 +40,12 @@ trait UpsertTests { (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) // inserts new row ).toDF("id", "ts", "v1", "v2").repartition(1) // .repartition(10) - val table = if (tableExists) "upsert" else "new_table" - val options = Map("numPartitions" -> "10", "upsert" -> "true", "upsertKeyColumns" -> "id, ts") + val table = if (tableExists) "upsert" else "new_upsert_table" + val options = upsertTestOptions ++ Map( + "numPartitions" -> "10", + "upsert" -> "true", + "upsertKeyColumns" -> "id, ts" + ) df.write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties) val actual = spark.read.jdbc(jdbcUrl, table, new Properties).collect.toSet From 7fd80b5e042e028eab1049d020cc482bc2e60605 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 8 Jun 2023 16:58:33 +0200 Subject: [PATCH 07/12] Add upsert concurrency integration test --- .../jdbc/MsSqlServerIntegrationSuite.scala | 2 + .../sql/jdbc/MySQLIntegrationSuite.scala | 4 +- .../sql/jdbc/PostgresIntegrationSuite.scala | 4 +- .../apache/spark/sql/jdbc/UpsertTests.scala | 50 ++++++++++++++++++- 4 files changed, 55 insertions(+), 5 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index c9bc5e3a538f7..5d98d17097a35 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -160,6 +160,8 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with Upsert "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate() } + override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)" + test("Basic test") { val df = spark.read.jdbc(jdbcUrl, "tbl", new Properties) val rows = df.collect() diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 7ab9eadddbee8..97bfeafcb9e7c 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -107,6 +107,8 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests "(2, '1996-01-01 01:23:46', 
2.346, 2.345679)").executeUpdate() } + override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)" + def testConnection(): Unit = { Using.resource(getConnection()) { conn => assert(conn.getClass.getName === "com.mysql.cj.jdbc.ConnectionImpl") @@ -383,7 +385,7 @@ class MySQLOverMariaConnectorIntegrationSuite extends MySQLIntegrationSuite { override val db = new MySQLDatabaseOnDocker { override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true" + - s"&useSSL=false" + s"&useSSL=false&allowMultiQueries=true" } override def testConnection(): Unit = { diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index c7d53eb75e38c..1738ff56f896e 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -200,9 +200,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTes "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate() } - override val upsertTestOptions = Map( - "createTableOptions" -> "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)" - ) + override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)" test("Type mapping for various types") { val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala index a06b3baf998b6..7230b1a7001c2 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala @@ -21,13 +21,15 @@ import java.sql.Timestamp import java.util.Properties import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.functions.{lit, rand, when} trait UpsertTests { self: DockerJDBCIntegrationSuite => import testImplicits._ - def upsertTestOptions: Map[String, String] = Map.empty + def createTableOption: String + def upsertTestOptions: Map[String, String] = Map("createTableOptions" -> createTableOption) test(s"Upsert existing table") { doTestUpsert(true) } test(s"Upsert non-existing table") { doTestUpsert(false) } @@ -66,6 +68,52 @@ trait UpsertTests { assert(actual === expected) } + test(s"Upsert concurrency") { + // create a table with 100k rows + val init = + spark.range(100000) + .withColumn("ts", lit(Timestamp.valueOf("2023-06-07 12:34:56"))) + .withColumn("v", rand()) + + // upsert 100 batches of 100 rows each + // run in 32 tasks + val patch = + spark.range(100) + .join(spark.range(100).select(($"id" * 1000).as("offset"))) + .repartition(32) + .select( + ($"id" + $"offset").as("id"), + lit(Timestamp.valueOf("2023-06-07 12:34:56")).as("ts"), + lit(-1.0).as("v") + ) + + spark.sparkContext.setJobDescription("init") + init + .write + .mode(SaveMode.Overwrite) + .option("createTableOptions", createTableOption) + .jdbc(jdbcUrl, "new_upsert_table", new Properties) + + spark.sparkContext.setJobDescription("patch") + patch + .write + .mode(SaveMode.Append) + .option("upsert", true) + 
.option("upsertKeyColumns", "id, ts") + .options(upsertTestOptions) + .jdbc(jdbcUrl, "new_upsert_table", new Properties) + + // check result table has 100*100 updated rows + val result = spark.read.jdbc(jdbcUrl, "new_upsert_table", new Properties) + .select($"id", when($"v" === -1.0, true).otherwise(false).as("patched")) + .groupBy($"patched") + .count() + .sort($"patched") + .as[(Boolean, Long)] + .collect() + assert(result === Seq((false, 90000), (true, 10000))) + } + test("Upsert null values") {} test("Write with unspecified mode with upsert") {} test("Write with overwrite mode with upsert") {} From 7602df499db3355a1f67ee4bb365bf0f2f776ea2 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Fri, 9 Jun 2023 10:38:06 +0200 Subject: [PATCH 08/12] Add tests with varying column order --- .../org/apache/spark/sql/jdbc/UpsertTests.scala | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala index 7230b1a7001c2..8abaa2d7a46f2 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala @@ -34,7 +34,18 @@ trait UpsertTests { test(s"Upsert existing table") { doTestUpsert(true) } test(s"Upsert non-existing table") { doTestUpsert(false) } - def doTestUpsert(tableExists: Boolean): Unit = { + Seq( + Seq("ts", "id", "v1", "v2"), + Seq("ts", "v1", "id", "v2"), + Seq("ts", "v1", "v2", "id"), + Seq("v2", "v1", "ts", "id"), + ).foreach { columns => + test(s"Upsert with varying column order - ${columns.mkString(",")}") { + doTestUpsert(tableExists = true, Some(columns)) + } + } + + def doTestUpsert(tableExists: Boolean, project: Option[Seq[String]] = None): Unit = { val df = Seq( (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), // row unchanged (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), // updates v1 @@ -48,7 +59,8 @@ trait UpsertTests { "upsert" -> "true", "upsertKeyColumns" -> "id, ts" ) - df.write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties) + project.map(df.select(_: _*)).getOrElse(df) + .write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties) val actual = spark.read.jdbc(jdbcUrl, table, new Properties).collect.toSet val existing = if (tableExists) { From 1b246448c35901ff20b71321f8a1b5533154d34e Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Fri, 9 Jun 2023 12:23:16 +0200 Subject: [PATCH 09/12] Add test with varying column order, sketch more tests --- .../org/apache/spark/sql/jdbc/UpsertTests.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala index 8abaa2d7a46f2..8528a0fa92ab4 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala @@ -21,7 +21,7 @@ import java.sql.Timestamp import java.util.Properties import org.apache.spark.sql.{Row, SaveMode} -import org.apache.spark.sql.functions.{lit, rand, when} +import org.apache.spark.sql.functions.{col, lit, rand, when} trait UpsertTests 
{ self: DockerJDBCIntegrationSuite => @@ -31,8 +31,8 @@ trait UpsertTests { def createTableOption: String def upsertTestOptions: Map[String, String] = Map("createTableOptions" -> createTableOption) - test(s"Upsert existing table") { doTestUpsert(true) } - test(s"Upsert non-existing table") { doTestUpsert(false) } + test(s"Upsert existing table") { doTestUpsert(tableExists = true) } + test(s"Upsert non-existing table") { doTestUpsert(tableExists = false) } Seq( Seq("ts", "id", "v1", "v2"), @@ -59,10 +59,10 @@ trait UpsertTests { "upsert" -> "true", "upsertKeyColumns" -> "id, ts" ) - project.map(df.select(_: _*)).getOrElse(df) + project.map(columns => df.select(columns.map(col): _*)).getOrElse(df) .write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties) - val actual = spark.read.jdbc(jdbcUrl, table, new Properties).collect.toSet + val actual = spark.read.jdbc(jdbcUrl, table, new Properties).collect().toSet val existing = if (tableExists) { Set((1, Timestamp.valueOf("1996-01-01 01:23:45"), 1.234, 1.234567)) } else { @@ -110,7 +110,7 @@ trait UpsertTests { patch .write .mode(SaveMode.Append) - .option("upsert", true) + .option("upsert", value = true) .option("upsertKeyColumns", "id, ts") .options(upsertTestOptions) .jdbc(jdbcUrl, "new_upsert_table", new Properties) @@ -126,7 +126,10 @@ trait UpsertTests { assert(result === Seq((false, 90000), (true, 10000))) } + test("Upsert with columns that require quotes") {} + test("Upsert with table name that requires quotes") {} test("Upsert null values") {} + test("Write with unspecified mode with upsert") {} test("Write with overwrite mode with upsert") {} test("Write with error-if-exists mode with upsert") {} From e0e65d3c61659e1de9191953318b548c1f6d4f43 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 15 Jun 2023 11:25:32 +0200 Subject: [PATCH 10/12] Revert empty line removal, fix scalastyle error --- .../src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala | 2 +- .../spark/sql/execution/datasources/jdbc/JDBCOptions.scala | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala index 8528a0fa92ab4..6b9cbfb3bd4cf 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala @@ -38,7 +38,7 @@ trait UpsertTests { Seq("ts", "id", "v1", "v2"), Seq("ts", "v1", "id", "v2"), Seq("ts", "v1", "v2", "id"), - Seq("v2", "v1", "ts", "id"), + Seq("v2", "v1", "ts", "id") ).foreach { columns => test(s"Upsert with varying column order - ${columns.mkString(",")}") { doTestUpsert(tableExists = true, Some(columns)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 7379feaa0e24c..0b6a02e161244 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -160,6 +160,7 @@ class JDBCOptions( // ------------------------------------------------------------ // if to truncate the table from the JDBC database val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean + val isCascadeTruncate: 
Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean) // if to upsert the table in the JDBC database val isUpsert = parameters.getOrElse(JDBC_UPSERT, "false").toBoolean From fc707f14be916d2ae957c5665328a53a2097c13b Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Tue, 18 Jul 2023 14:13:24 +0200 Subject: [PATCH 11/12] Refactor tableDoesNotSupportError to reuse in tableDoesNotSupportUpsertsError --- .../sql/errors/QueryCompilationErrors.scala | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index c11a577270476..f48d3f7c75ee9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1471,44 +1471,40 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("changes" -> changes.toString())) } - private def tableDoesNotSupportError(cmd: String, table: Table): Throwable = { + private def tableDoesNotSupportError(cmd: String, tableName: String): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1121", messageParameters = Map( "cmd" -> cmd, - "table" -> table.name)) + "table" -> tableName)) } def tableDoesNotSupportReadsError(table: Table): Throwable = { - tableDoesNotSupportError("reads", table) + tableDoesNotSupportError("reads", table.name) } def tableDoesNotSupportWritesError(table: Table): Throwable = { - tableDoesNotSupportError("writes", table) + tableDoesNotSupportError("writes", table.name) } def tableDoesNotSupportDeletesError(table: Table): Throwable = { - tableDoesNotSupportError("deletes", table) + tableDoesNotSupportError("deletes", table.name) } def tableDoesNotSupportTruncatesError(table: Table): Throwable = { - tableDoesNotSupportError("truncates", table) + tableDoesNotSupportError("truncates", table.name) } - def tableDoesNotSupportUpsertsError(table: String): Throwable = { - new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1121", - messageParameters = Map( - "cmd" -> "upserts", - "table" -> table)) + def tableDoesNotSupportUpsertsError(tableName: String): Throwable = { + tableDoesNotSupportError("upserts", tableName) } def tableDoesNotSupportPartitionManagementError(table: Table): Throwable = { - tableDoesNotSupportError("partition management", table) + tableDoesNotSupportError("partition management", table.name) } def tableDoesNotSupportAtomicPartitionManagementError(table: Table): Throwable = { - tableDoesNotSupportError("atomic partition management", table) + tableDoesNotSupportError("atomic partition management", table.name) } def tableIsNotRowLevelOperationTableError(table: Table): Throwable = { From 803e5ded70919f8c901b3f23dcc11800823cf394 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 25 Oct 2023 09:13:48 +0200 Subject: [PATCH 12/12] Fix after merge master --- .../jdbc/JdbcRelationProvider.scala | 2 +- .../spark/sql/jdbc/MsSqlServerDialect.scala | 19 +++++++++---------- .../apache/spark/sql/jdbc/MySQLDialect.scala | 6 +++--- .../spark/sql/jdbc/PostgresDialect.scala | 9 ++++----- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 11 +++++++---- 5 files changed, 24 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index d69d2cc58e077..0808d16749ff8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -68,7 +68,7 @@ class JdbcRelationProvider extends CreatableRelationProvider } case SaveMode.Append => - if (options.isUpsert && !dialect.supportsUpsert) { + if (options.isUpsert && !dialect.supportsUpsert()) { throw QueryCompilationErrors.tableDoesNotSupportUpsertsError(options.table) } val tableSchema = JdbcUtils.getSchemaOption(conn, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 1b10dbd2893f4..d1c73c2100ab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -143,35 +143,34 @@ private case class MsSqlServerDialect() extends JdbcDialect { override def getUpsertStatement( tableName: String, - columns: Array[String], - types: Array[DataType], + columns: Array[StructField], isCaseSensitive: Boolean, options: JDBCOptions): String = { - val insertColumns = columns.mkString(", ") - val inputs = types + val insertColumns = columns.map(_.name).map(quoteIdentifier) + val inputs = columns + .map(_.dataType) .map(t => JdbcUtils.getJdbcType(t, this).databaseTypeDefinition) .zipWithIndex.map { case (t, idx) => s"DECLARE @param$idx $t; SET @param$idx = ?;" }.mkString("\n") val values = columns.indices.map(i => s"@param$i").mkString(", ") - val quotedUpsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier) val keyColumns = columns.zipWithIndex.filter { - case (col, _) => quotedUpsertKeyColumns.contains(col) + case (col, _) => options.upsertKeyColumns.contains(col.name) } val updateColumns = columns.zipWithIndex.filterNot { - case (col, _) => quotedUpsertKeyColumns.contains(col) + case (col, _) => options.upsertKeyColumns.contains(col.name) } val whereClause = keyColumns.map { - case (key, idx) => s"$key = @param$idx" + case (key, idx) => s"${quoteIdentifier(key.name)} = @param$idx" }.mkString(" AND ") val updateClause = updateColumns.map { - case (col, idx) => s"$col = @param$idx" + case (col, idx) => s"${quoteIdentifier(col.name)} = @param$idx" }.mkString(", ") s""" |$inputs | - |INSERT $tableName ($insertColumns) + |INSERT $tableName (${insertColumns.mkString(", ")}) |SELECT $values |WHERE NOT EXISTS ( | SELECT 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index b5e309a7dee96..f954eda39e804 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -181,15 +181,15 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper { columns: Array[StructField], isCaseSensitive: Boolean, options: JDBCOptions): String = { - val insertColumns = columns.mkString(", ") + val insertColumns = columns.map(_.name).map(quoteIdentifier) val placeholders = columns.map(_ => "?").mkString(",") val upsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier) - val updateColumns = columns.filterNot(c => upsertKeyColumns.contains(c.name)) + val updateColumns = 
insertColumns.filterNot(upsertKeyColumns.contains) val updateClause = updateColumns.map(x => s"$x = VALUES($x)").mkString(", ") s""" - |INSERT INTO $tableName ($insertColumns) + |INSERT INTO $tableName (${insertColumns.mkString(", ")}) |VALUES ( $placeholders ) |ON DUPLICATE KEY UPDATE $updateClause |""".stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index f7f3b9272824b..b2a359a0a2ac0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -166,19 +166,18 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper { override def getUpsertStatement( tableName: String, - columns: Array[String], - types: Array[DataType], + columns: Array[StructField], isCaseSensitive: Boolean, options: JDBCOptions): String = { - val insertColumns = columns.mkString(", ") + val insertColumns = columns.map(_.name).map(quoteIdentifier) val placeholders = columns.map(_ => "?").mkString(",") val upsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier) - val updateColumns = columns.filterNot(upsertKeyColumns.contains) + val updateColumns = insertColumns.filterNot(upsertKeyColumns.contains) val updateClause = updateColumns.map(x => s"$x = EXCLUDED.$x").mkString(", ") s""" - |INSERT INTO $tableName ($insertColumns) + |INSERT INTO $tableName (${insertColumns.mkString(", ")}) |VALUES ( $placeholders ) |ON CONFLICT (${upsertKeyColumns.mkString(", ")}) |DO UPDATE SET $updateClause diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index b995a8c67310d..2bb705c75c987 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1206,11 +1206,14 @@ class JDBCSuite extends QueryTest with SharedSparkSession { } val table = "table" - val columns = Array("id", "time", "value", "comment") - val quotedColumns = columns.map(dialect.quoteIdentifier) - val types: Array[DataType] = Array(LongType, TimestampType, DoubleType, StringType) + val columns = Array( + StructField("id", LongType), + StructField("time", TimestampType), + StructField("value", DoubleType), + StructField("comment", StringType) + ) val isCaseSensitive = false - val stmt = dialect.getUpsertStatement(table, quotedColumns, types, isCaseSensitive, options) + val stmt = dialect.getUpsertStatement(table, columns, isCaseSensitive, options) assert(stmt === expected) }
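
A minimal end-to-end sketch of the writer options introduced in this series, assuming a SparkSession `spark` (with `import spark.implicits._`), a reachable MySQL `jdbcUrl`, and the `upsert(id, ts, v1, v2)` table created in the integration suite; the statements shown in the comments are roughly what the MySQL and PostgreSQL dialects above generate for key columns `id, ts`:

    import java.sql.Timestamp
    import java.util.Properties

    import org.apache.spark.sql.SaveMode
    import spark.implicits._

    // Rows to merge into the existing `upsert` table; (id, ts) is the upsert key.
    val updates = Seq(
      (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), // updates v1 of an existing row
      (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789)  // inserts a new row
    ).toDF("id", "ts", "v1", "v2")

    updates.write
      .mode(SaveMode.Append)                   // upsert is only honoured with SaveMode.Append
      .option("upsert", "true")
      .option("upsertKeyColumns", "id, ts")
      .jdbc(jdbcUrl, "upsert", new Properties)

    // MySQLDialect.getUpsertStatement emits roughly:
    //   INSERT INTO upsert (`id`, `ts`, `v1`, `v2`)
    //   VALUES ( ?,?,?,? )
    //   ON DUPLICATE KEY UPDATE `v1` = VALUES(`v1`), `v2` = VALUES(`v2`)
    //
    // PostgresDialect.getUpsertStatement emits roughly:
    //   INSERT INTO upsert ("id", "ts", "v1", "v2")
    //   VALUES ( ?,?,?,? )
    //   ON CONFLICT ("id", "ts")
    //   DO UPDATE SET "v1" = EXCLUDED."v1", "v2" = EXCLUDED."v2"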