diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index f89d55b20e21..05827cb8d717 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -169,8 +169,23 @@ object JdbcUtils extends Logging {
     val dialect = JdbcDialects.get(url)
     df.schema.fields foreach { field => {
       val name = field.name
-      val typ: String =
-        dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
+      val dbcolumntype = {
+        // Check whether the user specified a target database column type for this field.
+        if (field.metadata.contains("db.column.type")) {
+          val coltype: String = field.metadata.getString("db.column.type")
+          if (coltype != null) {
+            // Remove whitespace to avoid invalid SQL statements in the input;
+            // spaces are not required in database type names.
+            Some(coltype.replaceAll("\\s", ""))
+          } else {
+            None
+          }
+        } else {
+          None
+        }
+      }
+      val typ: String = dbcolumntype.
+        orElse(dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition)).getOrElse(
         field.dataType match {
           case IntegerType => "INTEGER"
           case LongType => "BIGINT"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e23ee6693133..37a022783595 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.jdbc
 import java.sql.DriverManager
 import java.util.Properties
 
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -151,4 +153,68 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
     assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
   }
+
+  test("Test write with user specified database types.") {
+
+    val valueArray = Array[Row](
+      Row.apply("dave", 1, "UsA", "expert in electric cars", BigDecimal(42.33456)),
+      Row.apply("mary", 2, "USA", "Building planes", BigDecimal(42.33456)),
+      Row.apply("kathy", 3, "France", null, BigDecimal(42.33456)),
+      Row.apply("Mike", 4, null, "video games", BigDecimal(42.33456))
+    )
+
+    val invalidType = s""" clob(20k)) as select * from foo; drop table foo;
+      | (select * from foo where 1=1""".stripMargin
+
+    val (varcharIgnoreMd: Metadata, clobMd: Metadata, decimalMd: Metadata, invalidMd: Metadata) = {
+      List("varchar_ignorecase(20)", "clob(20k)", "DECIMAL(31,2)", invalidType).map(dbcoltype => {
+        val metadataBuilder = new MetadataBuilder()
+        metadataBuilder.putString("db.column.type", dbcoltype)
+        metadataBuilder.build()
+      }) match {
+        case List(a, b, c, d) => (a, b, c, d)
+      }
+    }
+
+    val schema = StructType(
+      StructField("name", StringType) ::
+        StructField("id", IntegerType) ::
+        StructField("country", StringType, true, varcharIgnoreMd) ::
+        StructField("description", StringType, true, clobMd) ::
+        StructField("expense", DecimalType(38, 18)) ::
+        Nil)
+
+    val properties = new Properties()
+    val df = sqlContext.createDataFrame(sparkContext.parallelize(valueArray), schema)
+    assert(JdbcUtils.schemaString(df, url) ==
+      s"""name TEXT , id INTEGER , country varchar_ignorecase(20) ,
+        | description clob(20k) , expense DECIMAL(38,18) """.stripMargin.replaceAll("\n", ""))
+    df.write.jdbc(url, "TEST.USERDBTYPETEST", new Properties)
+    assert(2 == sqlContext.read.jdbc(url,
+      "(select * from TEST.USERDBTYPETEST where country='usa' )", properties).count)
+    assert(1 == sqlContext.read.jdbc(url,
+      "(select * from TEST.USERDBTYPETEST where description is null)", properties).count)
+    assert(1 == sqlContext.read.jdbc(url,
+      "(select * from TEST.USERDBTYPETEST where country is null)", properties).count)
+
+    // Test specifying a different decimal type for an existing data frame.
+    val newDF = df.withColumn("expense", col("expense").as("expense", decimalMd))
+    assert(JdbcUtils.schemaString(newDF, url) ==
+      s"""name TEXT , id INTEGER , country varchar_ignorecase(20) ,
+        | description clob(20k) , expense DECIMAL(31,2) """.stripMargin.replaceAll("\n", ""))
+    newDF.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.USERDBTYPETEST", properties)
+    assert(4 == sqlContext.read.jdbc(url, "TEST.USERDBTYPETEST", properties).count)
+    assert(BigDecimal(sqlContext.read.jdbc(url, "TEST.USERDBTYPETEST",
+      new Properties).collect()(0).get(4).asInstanceOf[java.math.BigDecimal]) == BigDecimal(42.33))
+
+    // Test invalid characters in the user-specified type.
+    val invalidDf = df.withColumn("expense", col("expense").as("expense", invalidMd))
+    assert(JdbcUtils.schemaString(invalidDf, url) ==
+      s"""name TEXT , id INTEGER , country varchar_ignorecase(20) , description clob(20k) ,
+        | expense clob(20k))asselect*fromfoo;droptablefoo;
+        |(select*fromfoowhere1=1 """.stripMargin.replaceAll("\n", ""))
+    intercept[org.h2.jdbc.JdbcSQLException] {
+      invalidDf.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.USERDBTYPETEST", properties)
+    }
+  }
 }
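
For context, a minimal usage sketch of the new `db.column.type` hook from a caller's point of view (the JDBC url, table name, and `df` below are placeholders, not part of this patch; it assumes an existing `sqlContext` and a `DataFrame` with a `StringType` column named `description`, as in the test above):

```scala
import java.util.Properties

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.MetadataBuilder

// Tag the column with the exact database type the generated CREATE TABLE
// statement should use, via the "db.column.type" metadata key.
val clobMd = new MetadataBuilder().putString("db.column.type", "clob(20k)").build()
val tagged = df.withColumn("description", col("description").as("description", clobMd))

// schemaString() now prefers the metadata over the dialect mapping, so the
// DDL contains `description clob(20k)` instead of the default `description TEXT`.
tagged.write.jdbc("jdbc:h2:mem:testdb", "TEST.NOTES", new Properties)
```

Because the lookup is `dbcolumntype.orElse(dialect.getJDBCType(...))`, the user-supplied metadata always takes precedence over the dialect's mapping; columns without the metadata key fall back to the existing behavior.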