@@ -57,6 +57,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
val table = jdbcOptions.table
val createTableOptions = jdbcOptions.createTableOptions
val isTruncate = jdbcOptions.isTruncate
val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis

val conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
try {
@@ -67,16 +68,18 @@ class JdbcRelationProvider extends CreatableRelationProvider
if (isTruncate && isCascadingTruncateTable(url) == Some(false)) {
// In this case, we should truncate table and then load.
truncateTable(conn, table)
saveTable(df, url, table, jdbcOptions)
val tableSchema = JdbcUtils.getSchemaOption(conn, url, table)
Member Author:
I moved this into the case statements. Since JdbcUtils.tableExists is already used, getSchemaOption can be skipped for the other SaveModes.

saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions)
} else {
// Otherwise, do not truncate the table, instead drop and recreate it
dropTable(conn, table)
createTable(df.schema, url, table, createTableOptions, conn)
saveTable(df, url, table, jdbcOptions)
saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions)
}

case SaveMode.Append =>
saveTable(df, url, table, jdbcOptions)
val tableSchema = JdbcUtils.getSchemaOption(conn, url, table)
saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions)

case SaveMode.ErrorIfExists =>
throw new AnalysisException(
@@ -89,7 +92,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
}
} else {
createTable(df.schema, url, table, createTableOptions, conn)
saveTable(df, url, table, jdbcOptions)
saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions)
}
} finally {
conn.close()
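For context, a hedged sketch of the user-level calls that exercise the branches above (not part of the patch; assumes df is an existing DataFrame, url is a JDBC URL, and the table names are placeholders):

import java.util.Properties
import org.apache.spark.sql.SaveMode

// Overwrite with truncate=true: the table is truncated, its existing schema is read back via
// JdbcUtils.getSchemaOption, and the rows are inserted against that schema.
df.write.mode(SaveMode.Overwrite).option("truncate", true)
  .jdbc(url, "TEST.TRUNCATETEST", new Properties())

// Append: the existing schema is read back as well, so the INSERT statement can follow the
// table's column names. ErrorIfExists and Ignore never need the lookup.
df.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())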
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.spark.TaskContext
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
@@ -108,14 +108,36 @@ object JdbcUtils extends Logging {
}

/**
* Returns a PreparedStatement that inserts a row into table via conn.
* Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
*/
def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
: PreparedStatement = {
val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
def getInsertStatement(
table: String,
rddSchema: StructType,
tableSchema: Option[StructType],
isCaseSensitive: Boolean,
dialect: JdbcDialect): String = {
val columns = if (tableSchema.isEmpty) {
rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
Member Author:
The legacy behavior is used when tableSchema is None.
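A minimal sketch of this legacy path (not part of the patch; assumes the default dialect, which quotes identifiers with double quotes, and a placeholder table TEST.PEOPLE):

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types._

val rddSchema = StructType(StructField("name", StringType) :: StructField("id", IntegerType) :: Nil)
val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")

// tableSchema = None: the RDD column names are quoted as-is.
JdbcUtils.getInsertStatement("TEST.PEOPLE", rddSchema, None, isCaseSensitive = false, dialect)
// => INSERT INTO TEST.PEOPLE ("name","id") VALUES (?,?)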

} else {
val columnNameEquality = if (isCaseSensitive) {
org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
} else {
org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
}
// The generated insert statement needs to follow rddSchema's column sequence and
// tableSchema's column names. When appending data into some case-sensitive DBMSs like
// PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
// RDD column names for user convenience.
val tableColumnNames = tableSchema.get.fieldNames
rddSchema.fields.map { col =>
val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
}
dialect.quoteIdentifier(normalizedName)
}.mkString(",")
}
val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
conn.prepareStatement(sql)
s"INSERT INTO $table ($columns) VALUES ($placeholders)"
}
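Continuing the sketch above (same rddSchema and dialect), the branch with an existing table schema behaves as follows; again this is only an illustration under the stated assumptions:

// The target table was created with upper-case column names.
val tableSchema = Some(StructType(StructField("NAME", StringType) :: StructField("ID", IntegerType) :: Nil))

// Case-insensitive analysis (the default): the existing table's column names win.
JdbcUtils.getInsertStatement("TEST.PEOPLE", rddSchema, tableSchema, isCaseSensitive = false, dialect)
// => INSERT INTO TEST.PEOPLE ("NAME","ID") VALUES (?,?)

// Case-sensitive analysis: "name" does not resolve against "NAME", so this throws
// AnalysisException: Column "name" not found in schema ...
JdbcUtils.getInsertStatement("TEST.PEOPLE", rddSchema, tableSchema, isCaseSensitive = true, dialect)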

/**
@@ -210,6 +232,26 @@ object JdbcUtils extends Logging {
answer
}

/**
* Returns the schema if the table already exists in the JDBC database.
*/
def getSchemaOption(conn: Connection, url: String, table: String): Option[StructType] = {
val dialect = JdbcDialects.get(url)

try {
val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
@gatorsmile (Member), Dec 28, 2016:
Why should this be included in the try block?

Member Author:
It's because the Connection method is declared as PreparedStatement prepareStatement(String sql) throws SQLException. I want to make sure we return None here instead of raising an SQLException at this line.

Member Author:
Or is it better not to hide that here? If so, I'll remove that try.

Member:
JDBC data sources might detect whether the table exists when preparing the SQL statement. Thus, please keep it. Thanks

try {
Some(getSchema(statement.executeQuery(), dialect))
Member:
getSchema will throw an exception when the schema contains an unsupported type. Now we use it to check whether the table exists. Does that change the current behavior? E.g., an insertion that worked before could now fail.

Member:
For unsupported types, it will throw an SQLException, right?

Member:
Yes.

Member:
Then we can keep the existing way. See master...gatorsmile:pr-15664

} catch {
case _: SQLException => None
} finally {
statement.close()
}
} catch {
case _: SQLException => None
}
}
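A usage sketch tying the discussion above together (not part of the patch; assumes conn is an open java.sql.Connection, url is the matching JDBC URL, and the table names are placeholders):

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils

// Existing table: its schema comes back as a Catalyst StructType.
JdbcUtils.getSchemaOption(conn, url, "TEST.PEOPLE")         // e.g. Some(StructType(...))

// Missing table (or a failure while preparing/executing the schema query): the SQLException is
// caught and the method returns None, so callers fall back to the legacy behavior.
JdbcUtils.getSchemaOption(conn, url, "TEST.DOES_NOT_EXIST") // None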

/**
* Takes a [[ResultSet]] and returns its Catalyst schema.
*
@@ -531,7 +573,7 @@ object JdbcUtils extends Logging {
table: String,
iterator: Iterator[Row],
rddSchema: StructType,
nullTypes: Array[Int],
insertStmt: String,
batchSize: Int,
dialect: JdbcDialect,
isolationLevel: Int): Iterator[Byte] = {
@@ -568,9 +610,9 @@
conn.setAutoCommit(false) // Everything in the same db transaction.
conn.setTransactionIsolation(finalIsolationLevel)
}
val stmt = insertStatement(conn, table, rddSchema, dialect)
val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
.map(makeSetter(conn, dialect, _)).toArray
val stmt = conn.prepareStatement(insertStmt)
val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
val numFields = rddSchema.fields.length

try {
@@ -657,16 +699,16 @@
df: DataFrame,
url: String,
table: String,
tableSchema: Option[StructType],
isCaseSensitive: Boolean,
options: JDBCOptions): Unit = {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
getJdbcType(field.dataType, dialect).jdbcNullType
}

val rddSchema = df.schema
val getConnection: () => Connection = createConnectionFactory(options)
val batchSize = options.batchSize
val isolationLevel = options.isolationLevel

val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
val repartitionedDF = options.numPartitions match {
case Some(n) if n <= 0 => throw new IllegalArgumentException(
s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
@@ -675,7 +717,7 @@
case _ => df
}
repartitionedDF.foreachPartition(iterator => savePartition(
getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
)
}

@@ -24,9 +24,9 @@ import scala.collection.JavaConverters.propertiesAsScalaMapConverter

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkException
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.{AnalysisException, Row, SaveMode}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -96,6 +96,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
StructField("id", IntegerType) ::
StructField("seq", IntegerType) :: Nil)

private lazy val schema4 = StructType(
StructField("NAME", StringType) ::
StructField("ID", IntegerType) :: Nil)

test("Basic CREATE") {
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)

@@ -165,6 +169,26 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
}

test("SPARK-18123 Append with column names with different cases") {
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema4)

df.write.jdbc(url, "TEST.APPENDTEST", new Properties())

withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
val m = intercept[AnalysisException] {
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())
}.getMessage
assert(m.contains("Column \"NAME\" not found"))
}

withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())
assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).count())
assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
}
}

test("Truncate") {
JdbcDialects.registerDialect(testH2Dialect)
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
@@ -177,7 +201,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)

val m = intercept[SparkException] {
val m = intercept[AnalysisException] {
df3.write.mode(SaveMode.Overwrite).option("truncate", true)
.jdbc(url1, "TEST.TRUNCATETEST", properties)
}.getMessage
@@ -203,9 +227,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)

df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
intercept[org.apache.spark.SparkException] {
val m = intercept[AnalysisException] {
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
}
}.getMessage
assert(m.contains("Column \"seq\" not found"))
}

test("INSERT to JDBC Datasource") {