Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,23 @@ object JdbcUtils extends Logging {
val dialect = JdbcDialects.get(url)
df.schema.fields foreach { field => {
val name = field.name
val typ: String =
dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
val dbcolumntype = {
// check if user specified target database column type for the field.
if (field.metadata.contains("db.column.type")) {
val coltype: String = field.metadata.getString("db.column.type")
if (coltype != null) {
// remove spaces to avoid any invalid sql statements in the input.
// spaces are not required in the database type names.
Some(coltype.replaceAll("\\s", ""))
} else {
None
}
} else {
None
}
}
val typ: String = dbcolumntype.
orElse(dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition)).getOrElse(
field.dataType match {
case IntegerType => "INTEGER"
case LongType => "BIGINT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ package org.apache.spark.sql.jdbc
import java.sql.DriverManager
import java.util.Properties

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -151,4 +153,68 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}

test("Test write with user specified database types.") {

val valueArray = Array[Row](
Row.apply("dave", 1, "UsA", "expert in electric cars", BigDecimal(42.33456)),
Row.apply("mary", 2, "USA", "Building planes", BigDecimal(42.33456)),
Row.apply("kathy", 3, "France", null, BigDecimal(42.33456)),
Row.apply("Mike", 4, null, "video games", BigDecimal(42.33456))
)

val invalidType = s""" clob(20k)) as select * from foo; drop table foo;
| (select * from foo where 1=1""".stripMargin

val (varcharIgnoreMd: Metadata, clobMd: Metadata, decimalMd: Metadata, invalidMd: Metadata) = {
List("varchar_ignorecase(20)", "clob(20k)", "DECIMAL(31,2)", invalidType).map(dbcoltype => {
val metadataBuilder = new MetadataBuilder()
metadataBuilder.putString("db.column.type", dbcoltype)
metadataBuilder.build()
}) match {
case List(a, b, c, d) => (a, b, c, d)
}
}

val schema = StructType(
StructField("name", StringType) ::
StructField("id", IntegerType) ::
StructField("country", StringType, true, varcharIgnoreMd) ::
StructField("description", StringType, true, clobMd) ::
StructField("expense", DecimalType(38, 18)) ::
Nil)

val properties = new Properties()
val df = sqlContext.createDataFrame(sparkContext.parallelize(valueArray), schema)
assert(JdbcUtils.schemaString(df, url) ==
s"""name TEXT , id INTEGER , country varchar_ignorecase(20) ,
| description clob(20k) , expense DECIMAL(38,18) """.stripMargin.replaceAll("\n", ""))
df.write.jdbc(url, "TEST.USERDBTYPETEST", new Properties)
assert(2 == sqlContext.read.jdbc(url,
"(select * from TEST.USERDBTYPETEST where country='usa' )", properties).count)
assert(1 == sqlContext.read.jdbc(url,
"(select * from TEST.USERDBTYPETEST where description is null)", properties).count)
assert(1 == sqlContext.read.jdbc(url,
"(select * from TEST.USERDBTYPETEST where country is null)", properties).count)

// test specifying a different decimal type for existing data frame.
val newDF = df.withColumn("expense", col("expense").as("expense", decimalMd))
assert(JdbcUtils.schemaString(newDF, url) ==
s"""name TEXT , id INTEGER , country varchar_ignorecase(20) ,
| description clob(20k) , expense DECIMAL(31,2) """.stripMargin.replaceAll("\n", ""))
newDF.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.USERDBTYPETEST", properties)
assert(4 == sqlContext.read.jdbc(url, "TEST.USERDBTYPETEST", properties).count)
assert(BigDecimal(sqlContext.read.jdbc(url, "TEST.USERDBTYPETEST",
new Properties).collect()(0).get(4).asInstanceOf[java.math.BigDecimal]) == BigDecimal(42.33))

// test invalid character in the user specified type
val invalidDf = df.withColumn("expense", col("expense").as("expense", invalidMd))
assert(JdbcUtils.schemaString(invalidDf, url) ==
s"""name TEXT , id INTEGER , country varchar_ignorecase(20) , description clob(20k) ,
| expense clob(20k))asselect*fromfoo;droptablefoo;
|(select*fromfoowhere1=1 """.stripMargin.replaceAll("\n", ""))
intercept[org.h2.jdbc.JdbcSQLException] {
invalidDf.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.USERDBTYPETEST", properties)
}
}
}