
Commit ce089e8

dongjoon-hyun authored and cmonkey committed
[SPARK-18123][SQL] Use db column names instead of RDD column ones during JDBC Writing
## What changes were proposed in this pull request?

Apache Spark supports the following cases **by quoting RDD column names** while saving through JDBC:

- Allow a reserved keyword as a column name, e.g., `order`.
- Allow mixed-case column names, e.g., `[a: int, A: int]`, as in the following:

```scala
scala> val df = sql("select 1 a, 1 A")
df: org.apache.spark.sql.DataFrame = [a: int, A: int]
...
scala> df.write.mode("overwrite").format("jdbc").options(option).save()
scala> df.write.mode("append").format("jdbc").options(option).save()
```

This PR aims to use **database column names** instead of RDD column names in order to additionally support the following case. Note that this case succeeds with `MySQL`, but previously failed on `Postgres`/`Oracle`:

```scala
val df1 = sql("select 1 a")
val df2 = sql("select 1 A")
...
df1.write.mode("overwrite").format("jdbc").options(option).save()
df2.write.mode("append").format("jdbc").options(option).save()
```

## How was this patch tested?

Pass the Jenkins tests with a new test case.

Author: Dongjoon Hyun <[email protected]>
Author: gatorsmile <[email protected]>

Closes apache#15664 from dongjoon-hyun/SPARK-18123.
1 parent 697858a · commit ce089e8

File tree

3 files changed (+95, -25 lines changed)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala

Lines changed: 7 additions & 4 deletions
```diff
@@ -57,6 +57,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
     val table = jdbcOptions.table
     val createTableOptions = jdbcOptions.createTableOptions
     val isTruncate = jdbcOptions.isTruncate
+    val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
 
     val conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
     try {
@@ -67,16 +68,18 @@ class JdbcRelationProvider extends CreatableRelationProvider
         if (isTruncate && isCascadingTruncateTable(url) == Some(false)) {
           // In this case, we should truncate table and then load.
           truncateTable(conn, table)
-          saveTable(df, url, table, jdbcOptions)
+          val tableSchema = JdbcUtils.getSchemaOption(conn, url, table)
+          saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions)
         } else {
           // Otherwise, do not truncate the table, instead drop and recreate it
           dropTable(conn, table)
           createTable(df.schema, url, table, createTableOptions, conn)
-          saveTable(df, url, table, jdbcOptions)
+          saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions)
         }
 
       case SaveMode.Append =>
-        saveTable(df, url, table, jdbcOptions)
+        val tableSchema = JdbcUtils.getSchemaOption(conn, url, table)
+        saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions)
 
       case SaveMode.ErrorIfExists =>
         throw new AnalysisException(
@@ -89,7 +92,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
         }
       } else {
         createTable(df.schema, url, table, createTableOptions, conn)
-        saveTable(df, url, table, jdbcOptions)
+        saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions)
       }
     } finally {
       conn.close()
```
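The practical effect of the provider change above shows up in the Append path: the existing table schema is fetched via `JdbcUtils.getSchemaOption` and passed to `saveTable`, so RDD column names are resolved against the table's own names. A minimal usage sketch of the scenario from the PR description, in spark-shell style, assuming an illustrative option map and a reachable PostgreSQL endpoint:

```scala
// Connection details are assumptions for illustration only.
val option = Map(
  "url"      -> "jdbc:postgresql://localhost/testdb",
  "dbtable"  -> "test_table",
  "user"     -> "user",
  "password" -> "password")

val df1 = sql("select 1 a")  // DataFrame column: a
val df2 = sql("select 1 A")  // DataFrame column: A

df1.write.mode("overwrite").format("jdbc").options(option).save()
// Before this change, the append below built the INSERT statement from the RDD column name "A"
// and failed on Postgres/Oracle; it now reuses the existing table column name "a".
df2.write.mode("append").format("jdbc").options(option).save()
```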

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

Lines changed: 58 additions & 16 deletions
```diff
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
 import org.apache.spark.TaskContext
 import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
@@ -108,14 +108,36 @@ object JdbcUtils extends Logging {
   }
 
   /**
-   * Returns a PreparedStatement that inserts a row into table via conn.
+   * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
    */
-  def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
-      : PreparedStatement = {
-    val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
+  def getInsertStatement(
+      table: String,
+      rddSchema: StructType,
+      tableSchema: Option[StructType],
+      isCaseSensitive: Boolean,
+      dialect: JdbcDialect): String = {
+    val columns = if (tableSchema.isEmpty) {
+      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
+    } else {
+      val columnNameEquality = if (isCaseSensitive) {
+        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+      } else {
+        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+      }
+      // The generated insert statement needs to follow rddSchema's column sequence and
+      // tableSchema's column names. When appending data into some case-sensitive DBMSs like
+      // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
+      // RDD column names for user convenience.
+      val tableColumnNames = tableSchema.get.fieldNames
+      rddSchema.fields.map { col =>
+        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
+          throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
+        }
+        dialect.quoteIdentifier(normalizedName)
+      }.mkString(",")
+    }
     val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
-    val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
-    conn.prepareStatement(sql)
+    s"INSERT INTO $table ($columns) VALUES ($placeholders)"
   }
 
   /**
```
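To make the resolution rule in `getInsertStatement` concrete, here is a minimal standalone sketch (a hypothetical helper, not the committed code; the exception type is illustrative): each RDD column keeps its position, but its name is replaced by the matching column name of the existing table, compared case-insensitively unless `spark.sql.caseSensitive` is enabled.

```scala
// Hypothetical helper mirroring the resolution performed by getInsertStatement above;
// the name and exception type are illustrative, not Spark's internals.
def resolveInsertColumns(
    rddColumns: Seq[String],
    tableColumns: Seq[String],
    caseSensitive: Boolean): Seq[String] = {
  // Column-name equality: exact match when spark.sql.caseSensitive is true, otherwise ignore case.
  def sameName(a: String, b: String): Boolean =
    if (caseSensitive) a == b else a.equalsIgnoreCase(b)

  rddColumns.map { rddCol =>
    tableColumns.find(tableCol => sameName(tableCol, rddCol)).getOrElse(
      throw new IllegalArgumentException(s"""Column "$rddCol" not found in table schema"""))
  }
}

// resolveInsertColumns(Seq("A"), Seq("a", "id"), caseSensitive = false)  // => Seq("a")
// resolveInsertColumns(Seq("A"), Seq("a", "id"), caseSensitive = true)   // throws
```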
```diff
@@ -210,6 +232,26 @@ object JdbcUtils extends Logging {
     answer
   }
 
+  /**
+   * Returns the schema if the table already exists in the JDBC database.
+   */
+  def getSchemaOption(conn: Connection, url: String, table: String): Option[StructType] = {
+    val dialect = JdbcDialects.get(url)
+
+    try {
+      val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
+      try {
+        Some(getSchema(statement.executeQuery(), dialect))
+      } catch {
+        case _: SQLException => None
+      } finally {
+        statement.close()
+      }
+    } catch {
+      case _: SQLException => None
+    }
+  }
+
   /**
    * Takes a [[ResultSet]] and returns its Catalyst schema.
    *
```
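The `getSchemaOption` helper above probes the table with the dialect's schema query and treats any `SQLException` as "table not present". A rough plain-JDBC sketch of the same idea, assuming a generic `SELECT ... WHERE 1=0` probe in place of `dialect.getSchemaQuery` and returning only column names rather than a Catalyst schema:

```scala
import java.sql.{Connection, SQLException}

// Standalone illustration of the "probe the table, treat failure as absence" idea behind
// getSchemaOption; the probe query is an assumption, since Spark delegates it to the dialect.
def tableColumnNames(conn: Connection, table: String): Option[Seq[String]] =
  try {
    val stmt = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0")
    try {
      val meta = stmt.executeQuery().getMetaData
      Some((1 to meta.getColumnCount).map(i => meta.getColumnLabel(i)))
    } finally {
      stmt.close()
    }
  } catch {
    case _: SQLException => None
  }
```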
```diff
@@ -531,7 +573,7 @@ object JdbcUtils extends Logging {
       table: String,
       iterator: Iterator[Row],
       rddSchema: StructType,
-      nullTypes: Array[Int],
+      insertStmt: String,
       batchSize: Int,
       dialect: JdbcDialect,
       isolationLevel: Int): Iterator[Byte] = {
@@ -568,9 +610,9 @@
       conn.setAutoCommit(false) // Everything in the same db transaction.
       conn.setTransactionIsolation(finalIsolationLevel)
     }
-    val stmt = insertStatement(conn, table, rddSchema, dialect)
-    val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
-      .map(makeSetter(conn, dialect, _)).toArray
+    val stmt = conn.prepareStatement(insertStmt)
+    val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
+    val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
     val numFields = rddSchema.fields.length
 
     try {
@@ -657,16 +699,16 @@
       df: DataFrame,
       url: String,
       table: String,
+      tableSchema: Option[StructType],
+      isCaseSensitive: Boolean,
       options: JDBCOptions): Unit = {
     val dialect = JdbcDialects.get(url)
-    val nullTypes: Array[Int] = df.schema.fields.map { field =>
-      getJdbcType(field.dataType, dialect).jdbcNullType
-    }
-
     val rddSchema = df.schema
     val getConnection: () => Connection = createConnectionFactory(options)
     val batchSize = options.batchSize
     val isolationLevel = options.isolationLevel
+
+    val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
     val repartitionedDF = options.numPartitions match {
       case Some(n) if n <= 0 => throw new IllegalArgumentException(
         s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
@@ -675,7 +717,7 @@
       case _ => df
     }
     repartitionedDF.foreachPartition(iterator => savePartition(
-      getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
+      getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
     )
   }
```

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala

Lines changed: 30 additions & 5 deletions
```diff
@@ -24,9 +24,9 @@ import scala.collection.JavaConverters.propertiesAsScalaMapConverter
 
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.SparkException
-import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, Row, SaveMode}
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -96,6 +96,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     StructField("id", IntegerType) ::
     StructField("seq", IntegerType) :: Nil)
 
+  private lazy val schema4 = StructType(
+    StructField("NAME", StringType) ::
+    StructField("ID", IntegerType) :: Nil)
+
   test("Basic CREATE") {
     val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
 
@@ -165,6 +169,26 @@
     assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
   }
 
+  test("SPARK-18123 Append with column names with different cases") {
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+    val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema4)
+
+    df.write.jdbc(url, "TEST.APPENDTEST", new Properties())
+
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      val m = intercept[AnalysisException] {
+        df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())
+      }.getMessage
+      assert(m.contains("Column \"NAME\" not found"))
+    }
+
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())
+      assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).count())
+      assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
+    }
+  }
+
   test("Truncate") {
     JdbcDialects.registerDialect(testH2Dialect)
     val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
@@ -177,7 +201,7 @@
     assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
     assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
 
-    val m = intercept[SparkException] {
+    val m = intercept[AnalysisException] {
       df3.write.mode(SaveMode.Overwrite).option("truncate", true)
         .jdbc(url1, "TEST.TRUNCATETEST", properties)
     }.getMessage
@@ -203,9 +227,10 @@
     val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
 
     df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
-    intercept[org.apache.spark.SparkException] {
+    val m = intercept[AnalysisException] {
      df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
-    }
+    }.getMessage
+    assert(m.contains("Column \"seq\" not found"))
   }
 
   test("INSERT to JDBC Datasource") {
```
