MsSqlServerIntegrationSuite.scala
@@ -37,7 +37,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new DatabaseOnDocker {
override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME",
"mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04")
@@ -150,8 +150,18 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite
"""
|INSERT INTO bits VALUES (1, 2, 1)
""".stripMargin).executeUpdate()

conn.prepareStatement("CREATE TABLE upsert (id INT, ts DATETIME, v1 FLOAT, v2 FLOAT, " +
"CONSTRAINT pk_upsert PRIMARY KEY (id, ts))").executeUpdate()
conn.prepareStatement("INSERT INTO upsert VALUES " +
"(1, '1996-01-01 01:23:45', 1.234, 1.234567), " +
"(1, '1996-01-01 01:23:46', 1.235, 1.234568), " +
"(2, '1996-01-01 01:23:45', 2.345, 2.345678), " +
"(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate()
}

override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)"

test("Basic test") {
val df = spark.read.jdbc(jdbcUrl, "tbl", new Properties)
val rows = df.collect()
@@ -437,4 +447,5 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite
.load()
assert(df.collect().toSet === expectedResult)
}

}
MySQLIntegrationSuite.scala
@@ -39,7 +39,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new MySQLDatabaseOnDocker

override def dataPreparation(conn: Connection): Unit = {
@@ -97,8 +97,18 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite
conn.prepareStatement("CREATE TABLE TBL_GEOMETRY (col0 GEOMETRY)").executeUpdate()
conn.prepareStatement("INSERT INTO TBL_GEOMETRY VALUES (ST_GeomFromText('POINT(0 0)'))")
.executeUpdate()

conn.prepareStatement("CREATE TABLE upsert (id INTEGER, ts TIMESTAMP, v1 DOUBLE, v2 DOUBLE, " +
"PRIMARY KEY pk (id, ts))").executeUpdate()
conn.prepareStatement("INSERT INTO upsert VALUES " +
"(1, '1996-01-01 01:23:45', 1.234, 1.234567), " +
"(1, '1996-01-01 01:23:46', 1.235, 1.234568), " +
"(2, '1996-01-01 01:23:45', 2.345, 2.345678), " +
"(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate()
}

override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)"

def testConnection(): Unit = {
Using.resource(getConnection()) { conn =>
assert(conn.getClass.getName === "com.mysql.cj.jdbc.ConnectionImpl")
@@ -285,7 +295,6 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite
assert(sql("select x, y from queryOption").collect().toSet == expectedResult)
}


test("SPARK-47478: all boolean synonyms read-write roundtrip") {
val df = sqlContext.read.jdbc(jdbcUrl, "bools", new Properties)
checkAnswer(df, Row(true, true, true))
@@ -358,6 +367,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite
val df = spark.read.jdbc(jdbcUrl, "smallint_round_trip", new Properties)
assert(df.schema.fields.head.dataType === ShortType)
}

}


@@ -375,7 +385,7 @@ class MySQLOverMariaConnectorIntegrationSuite extends MySQLIntegrationSuite {
override val db = new MySQLDatabaseOnDocker {
override def getJdbcUrl(ip: String, port: Int): String =
s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true" +
s"&useSSL=false"
s"&useSSL=false&allowMultiQueries=true"
}

override def testConnection(): Unit = {
PostgresIntegrationSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new DatabaseOnDocker {
override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine")
override val env = Map(
@@ -190,8 +190,18 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite
conn.prepareStatement("CREATE DOMAIN myint AS integer CHECK (VALUE > 0)").executeUpdate()
conn.prepareStatement("CREATE TABLE domain_table (c1 myint)").executeUpdate()
conn.prepareStatement("INSERT INTO domain_table VALUES (1)").executeUpdate()

conn.prepareStatement("CREATE TABLE upsert (id integer, ts timestamp, v1 double precision, " +
"v2 double precision, CONSTRAINT pk PRIMARY KEY (id, ts))").executeUpdate()
conn.prepareStatement("INSERT INTO upsert VALUES " +
"(1, '1996-01-01 01:23:45', 1.234, 1.234567), " +
"(1, '1996-01-01 01:23:46', 1.235, 1.234568), " +
"(2, '1996-01-01 01:23:45', 2.345, 2.345678), " +
"(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate()
}

override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)"

test("Type mapping for various types") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
val rows = df.collect().sortBy(_.toString())
@@ -321,6 +331,16 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite
.collect()
rows.head.toSeq.tail.foreach(c => assert(c.isInstanceOf[java.sql.Timestamp]))
}

test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE " +
"should be recognized") {
// When using JDBC to read the columns of TIMESTAMP with TIME ZONE and TIME with TIME ZONE
// the actual types are java.sql.Types.TIMESTAMP and java.sql.Types.TIME
val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties)
val rows = dfRead.collect()
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types(1).equals("class java.sql.Timestamp"))
assert(types(2).equals("class java.sql.Timestamp"))
}

test("SPARK-22291: Conversion error when transforming array types of " +
UpsertTests.scala
@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.jdbc

import java.sql.Timestamp
import java.util.Properties

import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.functions.{col, lit, rand, when}

trait UpsertTests {
self: DockerJDBCIntegrationSuite =>

import testImplicits._

def createTableOption: String
def upsertTestOptions: Map[String, String] = Map("createTableOptions" -> createTableOption)

test(s"Upsert existing table") { doTestUpsert(tableExists = true) }
test(s"Upsert non-existing table") { doTestUpsert(tableExists = false) }

Seq(
Seq("ts", "id", "v1", "v2"),
Seq("ts", "v1", "id", "v2"),
Seq("ts", "v1", "v2", "id"),
Seq("v2", "v1", "ts", "id")
).foreach { columns =>
test(s"Upsert with varying column order - ${columns.mkString(",")}") {
doTestUpsert(tableExists = true, Some(columns))
}
}

def doTestUpsert(tableExists: Boolean, project: Option[Seq[String]] = None): Unit = {
val df = Seq(
(1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), // row unchanged
(2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), // updates v1
(2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), // updates v1 and v2
(3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) // inserts new row
).toDF("id", "ts", "v1", "v2").repartition(1) // .repartition(10)

val table = if (tableExists) "upsert" else "new_upsert_table"
val options = upsertTestOptions ++ Map(
"numPartitions" -> "10",
"upsert" -> "true",
"upsertKeyColumns" -> "id, ts"
)
project.map(columns => df.select(columns.map(col): _*)).getOrElse(df)
.write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties)

val actual = spark.read.jdbc(jdbcUrl, table, new Properties).collect().toSet
val existing = if (tableExists) {
Set((1, Timestamp.valueOf("1996-01-01 01:23:45"), 1.234, 1.234567))
} else {
Set.empty
}
val upsertedRows = Set(
(1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568),
(2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678),
(2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680),
(3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789)
)
val expected = (existing ++ upsertedRows).map { case (id, ts, v1, v2) =>
Row(Integer.valueOf(id), ts, v1.doubleValue(), v2.doubleValue())
}
assert(actual === expected)
}

test(s"Upsert concurrency") {
// create a table with 100k rows
val init =
spark.range(100000)
.withColumn("ts", lit(Timestamp.valueOf("2023-06-07 12:34:56")))
.withColumn("v", rand())

// upsert 100 batches of 100 rows each
// run in 32 tasks
val patch =
spark.range(100)
.join(spark.range(100).select(($"id" * 1000).as("offset")))
.repartition(32)
.select(
($"id" + $"offset").as("id"),
lit(Timestamp.valueOf("2023-06-07 12:34:56")).as("ts"),
lit(-1.0).as("v")
)

spark.sparkContext.setJobDescription("init")
init
.write
.mode(SaveMode.Overwrite)
.option("createTableOptions", createTableOption)
.jdbc(jdbcUrl, "new_upsert_table", new Properties)

spark.sparkContext.setJobDescription("patch")
patch
.write
.mode(SaveMode.Append)
.option("upsert", value = true)
.option("upsertKeyColumns", "id, ts")
.options(upsertTestOptions)
.jdbc(jdbcUrl, "new_upsert_table", new Properties)

// check result table has 100*100 updated rows
val result = spark.read.jdbc(jdbcUrl, "new_upsert_table", new Properties)
.select($"id", when($"v" === -1.0, true).otherwise(false).as("patched"))
.groupBy($"patched")
.count()
.sort($"patched")
.as[(Boolean, Long)]
.collect()
assert(result === Seq((false, 90000), (true, 10000)))
}

test("Upsert with columns that require quotes") {}
test("Upsert with table name that requires quotes") {}
test("Upsert null values") {}

test("Write with unspecified mode with upsert") {}
test("Write with overwrite mode with upsert") {}
test("Write with error-if-exists mode with upsert") {}
test("Write with ignore mode with upsert") {}
}
docs/sql-data-sources-jdbc.md (13 additions & 0 deletions)
@@ -261,6 +261,19 @@ logging into the data sources.
<td>write</td>
</tr>

<tr>
<td><code>upsert</code>, <code>upsertKeyColumns</code></td>
<td><code>false</code>, (none)</td>
<td>
These JDBC writer options control how the UPSERT feature is used for the different JDBC dialects.
The <code>upsert</code> option only takes effect with <code>SaveMode.Append</code>; set it to <code>true</code>
to enable upsert append mode. By default, the database is queried for the primary index to detect the
upsert key columns, which identify the rows to update. The key columns can also be given explicitly via
<code>upsertKeyColumns</code> as a comma-separated list of column names.
Be aware that if the input data set contains duplicate rows, the result of the upsert operation is
non-deterministic; see <a href="https://en.wikipedia.org/wiki/Merge_(SQL)">MERGE (SQL)</a> for details.
</td>
<td>write</td>
</tr>

<tr>
<td><code>customSchema</code></td>
<td>(none)</td>
QueryCompilationErrors.scala
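To make the documented options concrete, here is a minimal write-path sketch; the SparkSession, JDBC URL, table name, and column names below are placeholders for illustration, not anything prescribed by this PR.

import java.sql.Timestamp
import java.util.Properties

import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().appName("upsert-sketch").getOrCreate()
import spark.implicits._

// Placeholder data with a composite key (id, ts), mirroring the test tables above.
val df = Seq(
  (1, Timestamp.valueOf("1996-01-01 01:23:45"), 1.5),
  (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.5)
).toDF("id", "ts", "v")

df.write
  .mode(SaveMode.Append)                 // upsert only takes effect in Append mode
  .option("upsert", "true")              // enable upsert append mode
  .option("upsertKeyColumns", "id, ts")  // optional; otherwise the primary index is consulted
  .jdbc("jdbc:postgresql://host:5432/db", "events", new Properties)  // placeholder URL and table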
@@ -1471,36 +1471,40 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map("changes" -> changes.toString()))
}

private def tableDoesNotSupportError(cmd: String, table: Table): Throwable = {
private def tableDoesNotSupportError(cmd: String, tableName: String): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1121",
messageParameters = Map(
"cmd" -> cmd,
"table" -> table.name))
"table" -> tableName))
}

def tableDoesNotSupportReadsError(table: Table): Throwable = {
tableDoesNotSupportError("reads", table)
tableDoesNotSupportError("reads", table.name)
}

def tableDoesNotSupportWritesError(table: Table): Throwable = {
tableDoesNotSupportError("writes", table)
tableDoesNotSupportError("writes", table.name)
}

def tableDoesNotSupportDeletesError(table: Table): Throwable = {
tableDoesNotSupportError("deletes", table)
tableDoesNotSupportError("deletes", table.name)
}

def tableDoesNotSupportTruncatesError(table: Table): Throwable = {
tableDoesNotSupportError("truncates", table)
tableDoesNotSupportError("truncates", table.name)
}

def tableDoesNotSupportUpsertsError(tableName: String): Throwable = {
tableDoesNotSupportError("upserts", tableName)
}

def tableDoesNotSupportPartitionManagementError(table: Table): Throwable = {
tableDoesNotSupportError("partition management", table)
tableDoesNotSupportError("partition management", table.name)
}

def tableDoesNotSupportAtomicPartitionManagementError(table: Table): Throwable = {
tableDoesNotSupportError("atomic partition management", table)
tableDoesNotSupportError("atomic partition management", table.name)
}

def tableIsNotRowLevelOperationTableError(table: Table): Throwable = {
JDBCOptions.scala
@@ -162,6 +162,11 @@ class JDBCOptions(
val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean

Contributor: It seems not necessary!
Contributor Author: fixed

val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean)
// whether to upsert into the table in the JDBC database
val isUpsert = parameters.getOrElse(JDBC_UPSERT, "false").toBoolean
// the columns used to identify update and insert rows in upsert mode
val upsertKeyColumns = parameters.getOrElse(JDBC_UPSERT_KEY_COLUMNS, "").split(",").map(_.trim)

// the create table option, which can be table_options or partition_options.
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
// TODO: to reuse the existing partition parameters for those partition specific options
@@ -296,6 +301,8 @@ object JDBCOptions {
val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate")
val JDBC_UPSERT = newOption("upsert")
Contributor: upserts or upsert?
Contributor Author: I'd go for upsert, as in upsert mode.

val JDBC_UPSERT_KEY_COLUMNS = newOption("upsertKeyColumns")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema")
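For background on what these options ultimately drive, the three dialects exercised in this PR conventionally express an upsert with ON DUPLICATE KEY UPDATE (MySQL), ON CONFLICT (PostgreSQL), and MERGE (SQL Server). The statements below are an illustrative sketch for the test table's schema, not the exact SQL the PR's dialect implementations generate.

// Illustrative SQL shapes only, for a table `upsert` with key columns (id, ts)
// and value columns (v1, v2); the PR's generated statements may differ in detail.
val mySqlUpsert =
  """INSERT INTO upsert (id, ts, v1, v2) VALUES (?, ?, ?, ?)
    |ON DUPLICATE KEY UPDATE v1 = VALUES(v1), v2 = VALUES(v2)""".stripMargin

val postgresUpsert =
  """INSERT INTO upsert (id, ts, v1, v2) VALUES (?, ?, ?, ?)
    |ON CONFLICT (id, ts) DO UPDATE SET v1 = EXCLUDED.v1, v2 = EXCLUDED.v2""".stripMargin

val msSqlServerUpsert =
  """MERGE INTO upsert AS t
    |USING (VALUES (?, ?, ?, ?)) AS s (id, ts, v1, v2)
    |ON t.id = s.id AND t.ts = s.ts
    |WHEN MATCHED THEN UPDATE SET t.v1 = s.v1, t.v2 = s.v2
    |WHEN NOT MATCHED THEN INSERT (id, ts, v1, v2) VALUES (s.id, s.ts, s.v1, s.v2);""".stripMargin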